Fixed DMs

This commit is contained in:
2025-12-28 01:09:10 -07:00
parent cffbb54a3e
commit 94d7cde6fa
2 changed files with 90 additions and 45 deletions

View File

@@ -48,7 +48,6 @@ func InitDB() {
FOREIGN KEY (room_id) REFERENCES rooms(id), FOREIGN KEY (room_id) REFERENCES rooms(id),
FOREIGN KEY (username) REFERENCES users(username) FOREIGN KEY (username) REFERENCES users(username)
); );
-- Stores the encrypted room key for a specific user for a specific epoch
CREATE TABLE IF NOT EXISTS user_room_keys ( CREATE TABLE IF NOT EXISTS user_room_keys (
username TEXT, username TEXT,
room_id TEXT, room_id TEXT,
@@ -132,6 +131,12 @@ func getRoomIDByName(name string) (string, error) {
return id, err return id, err
} }
func getRoomCreator(roomID string) (string, error) {
var creator string
err := db.QueryRow("SELECT creator FROM rooms WHERE id = ?", roomID).Scan(&creator)
return creator, err
}
func getRoomCurrentEpoch(roomID string) (int, error) { func getRoomCurrentEpoch(roomID string) (int, error) {
var epoch int var epoch int
err := db.QueryRow("SELECT current_epoch FROM rooms WHERE id = ?", roomID).Scan(&epoch) err := db.QueryRow("SELECT current_epoch FROM rooms WHERE id = ?", roomID).Scan(&epoch)
@@ -140,7 +145,6 @@ func getRoomCurrentEpoch(roomID string) (int, error) {
func incrementRoomEpoch(roomID string) (int, error) { func incrementRoomEpoch(roomID string) (int, error) {
var newEpoch int var newEpoch int
// Atomic increment
err := db.QueryRow("UPDATE rooms SET current_epoch = current_epoch + 1 WHERE id = ? RETURNING current_epoch", roomID).Scan(&newEpoch) err := db.QueryRow("UPDATE rooms SET current_epoch = current_epoch + 1 WHERE id = ? RETURNING current_epoch", roomID).Scan(&newEpoch)
return newEpoch, err return newEpoch, err
} }
@@ -175,7 +179,6 @@ func leaveRoom(roomID, username string) error {
} }
func joinRoomMember(roomID, username string) error { func joinRoomMember(roomID, username string) error {
// Check if already joined
var count int var count int
db.QueryRow("SELECT COUNT(*) FROM room_members WHERE room_id = ? AND username = ?", roomID, username).Scan(&count) db.QueryRow("SELECT COUNT(*) FROM room_members WHERE room_id = ? AND username = ?", roomID, username).Scan(&count)
if count > 0 { if count > 0 {

View File

@@ -51,7 +51,6 @@ type model struct {
currentRoomID string currentRoomID string
currentRoomEpoch int currentRoomEpoch int
// We cache keys: Epoch -> Plaintext Key
roomKeyCache map[int][]byte roomKeyCache map[int][]byte
rooms []struct { rooms []struct {
@@ -143,6 +142,11 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, tea.ClearScreen return m, tea.ClearScreen
case tea.KeyMsg: case tea.KeyMsg:
// Clear any displayed errors on user interaction
if m.err != nil {
m.err = nil
}
switch msg.Type { switch msg.Type {
case tea.KeyCtrlC, tea.KeyEsc: case tea.KeyCtrlC, tea.KeyEsc:
return m, tea.Quit return m, tea.Quit
@@ -226,23 +230,21 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
} }
if val == "/leave" { if val == "/leave" {
creator, _ := getRoomCreator(m.currentRoomID)
if creator == m.username {
m.saveMessage("System", "Error: Owners cannot leave. Use /delete to destroy room.")
m.input.Reset()
return m, nil
}
leaveRoom(m.currentRoomID, m.username) leaveRoom(m.currentRoomID, m.username)
m.rotateRoomKey(m.currentRoomID) // Rotate key on leave m.rotateRoomKey(m.currentRoomID)
m.exitChat() m.exitChat()
return m, tea.ClearScreen return m, tea.ClearScreen
} }
if val == "/delete" { if val == "/delete" {
isCreator := false creator, _ := getRoomCreator(m.currentRoomID)
for _, r := range m.rooms { if creator == m.username {
if r.ID == m.currentRoomID {
if r.Creator == m.username {
isCreator = true
}
break
}
}
if isCreator {
deleteRoom(m.currentRoomID) deleteRoom(m.currentRoomID)
m.exitChat() m.exitChat()
return m, tea.ClearScreen return m, tea.ClearScreen
@@ -289,15 +291,10 @@ func (m *model) loadRooms() {
m.rooms = rooms m.rooms = rooms
m.viewport.SetContent("") m.viewport.SetContent("")
m.input.SetValue("") m.input.SetValue("")
m.input.Placeholder = "Enter room # or: /join <name>, /dm <username>, /list" m.input.Placeholder = "Enter room # or: /join <name>, /dm <username>"
} }
func (m *model) handleRoomListInput(text string) { func (m *model) handleRoomListInput(text string) {
if text == "/list" {
m.loadRooms()
return
}
if strings.HasPrefix(text, "/join ") { if strings.HasPrefix(text, "/join ") {
roomName := strings.TrimPrefix(text, "/join ") roomName := strings.TrimPrefix(text, "/join ")
m.handleJoinRoom(roomName) m.handleJoinRoom(roomName)
@@ -324,18 +321,15 @@ func (m *model) handleRoomListInput(text string) {
func (m *model) handleJoinRoom(roomName string) { func (m *model) handleJoinRoom(roomName string) {
existingID, err := getRoomIDByName(roomName) existingID, err := getRoomIDByName(roomName)
if err == nil && existingID != "" { if err == nil && existingID != "" {
// Join existing
if err := joinRoomMember(existingID, m.username); err != nil { if err := joinRoomMember(existingID, m.username); err != nil {
m.err = err m.err = err
return return
} }
// ROTATE KEY so new user gets a key, but doesn't get old keys
m.rotateRoomKey(existingID) m.rotateRoomKey(existingID)
m.enterRoom(existingID, roomName) m.enterRoom(existingID, roomName)
return return
} }
// Create new
roomID := generateRoomID() roomID := generateRoomID()
if err := createRoom(roomID, roomName, m.username, false); err != nil { if err := createRoom(roomID, roomName, m.username, false); err != nil {
m.err = err m.err = err
@@ -345,7 +339,6 @@ func (m *model) handleJoinRoom(roomName string) {
m.err = err m.err = err
return return
} }
// Initial Key
m.rotateRoomKey(roomID) m.rotateRoomKey(roomID)
m.enterRoom(roomID, roomName) m.enterRoom(roomID, roomName)
} }
@@ -382,7 +375,7 @@ func (m *model) rotateRoomKey(roomID string) {
func (m *model) handleSelectUserForDM(username string) { func (m *model) handleSelectUserForDM(username string) {
if username == "/cancel" { if username == "/cancel" {
m.state = 1 m.state = 1
m.input.Placeholder = "Enter room # or: /join <name>, /dm <username>, /list" m.input.Placeholder = "Enter room # or: /join <name>, /dm <username>"
return return
} }
if username == m.username { if username == m.username {
@@ -390,6 +383,40 @@ func (m *model) handleSelectUserForDM(username string) {
return return
} }
// CHECK IF USER EXISTS
identityKey, prekey, prekeySignature, err := getUserKeys(username)
if err != nil {
if err == sql.ErrNoRows {
m.err = fmt.Errorf("user '%s' not found", username)
} else {
m.err = err
}
return
}
if len(identityKey) == 0 {
m.err = fmt.Errorf("user '%s' has no keys set up", username)
return
}
// Pre-calc shared secret for DM
var theirIdentityKey, theirPrekey [32]byte
copy(theirIdentityKey[:], identityKey)
copy(theirPrekey[:], prekey)
ephemeralKey, err := generateIdentityKeyPair()
if err != nil {
m.err = err
return
}
dh1, _ := performDH(m.identityKey.PrivateKey, theirPrekey)
dh2, _ := performDH(ephemeralKey.PrivateKey, theirIdentityKey)
dh3, _ := performDH(ephemeralKey.PrivateKey, theirPrekey)
_ = deriveSharedSecret(dh1, dh2, dh3)
_ = prekeySignature
roomID := generateRoomID() roomID := generateRoomID()
roomName := fmt.Sprintf("DM: %s <-> %s", m.username, username) roomName := fmt.Sprintf("DM: %s <-> %s", m.username, username)
@@ -418,10 +445,17 @@ func (m *model) createNoteToSelf() {
func (m *model) enterRoom(roomID, roomName string) { func (m *model) enterRoom(roomID, roomName string) {
epoch, _ := getRoomCurrentEpoch(roomID) epoch, _ := getRoomCurrentEpoch(roomID)
creator, _ := getRoomCreator(roomID)
m.currentRoomID = roomID m.currentRoomID = roomID
m.currentRoomEpoch = epoch m.currentRoomEpoch = epoch
m.state = 2 m.state = 2
m.input.Placeholder = fmt.Sprintf("[%s] /back, /leave, /delete", roomName)
if creator == m.username {
m.input.Placeholder = fmt.Sprintf("[%s] /back to menu, /delete to destroy", roomName)
} else {
m.input.Placeholder = fmt.Sprintf("[%s] /back to menu, /leave to quit", roomName)
}
m.loadMessages() m.loadMessages()
} }
@@ -442,28 +476,23 @@ func generateRoomID() string {
} }
func (m *model) getEpochKey(epoch int) []byte { func (m *model) getEpochKey(epoch int) []byte {
// Check cache
if key, ok := m.roomKeyCache[epoch]; ok { if key, ok := m.roomKeyCache[epoch]; ok {
return key return key
} }
encKey, err := getUserRoomKey(m.username, m.currentRoomID, epoch) encKey, err := getUserRoomKey(m.username, m.currentRoomID, epoch)
if err != nil || encKey == nil { if err != nil || encKey == nil {
return nil return nil
} }
key, err := DecryptKeyForUser(m.identityKey.PrivateKey, encKey) key, err := DecryptKeyForUser(m.identityKey.PrivateKey, encKey)
if err != nil { if err != nil {
return nil return nil
} }
m.roomKeyCache[epoch] = key m.roomKeyCache[epoch] = key
return key return key
} }
func (m *model) saveMessage(sender, text string) { func (m *model) saveMessage(sender, text string) {
m.currentRoomEpoch, _ = getRoomCurrentEpoch(m.currentRoomID) m.currentRoomEpoch, _ = getRoomCurrentEpoch(m.currentRoomID)
key := m.getEpochKey(m.currentRoomEpoch) key := m.getEpochKey(m.currentRoomEpoch)
if key == nil { if key == nil {
m.err = fmt.Errorf("no key for current epoch") m.err = fmt.Errorf("no key for current epoch")
@@ -555,29 +584,38 @@ No public key authentication provided.
} }
if m.needsRegistration { if m.needsRegistration {
errStr := ""
if m.err != nil {
errStr = "\n" + errStyle.Render(m.err.Error()) + "\n"
}
return fmt.Sprintf(` return fmt.Sprintf(`
%s %s
Welcome! You're a new user. Welcome! You're a new user.
Please choose a username to register. Please choose a username to register.
%s %s
`, sysStyle.Render("SECURE TUI CHAT - REGISTRATION"), m.input.View()) %s
`, sysStyle.Render("SECURE TUI CHAT - REGISTRATION"), errStr, m.input.View())
} }
if m.state == 0 { if m.state == 0 {
errStr := ""
if m.err != nil {
errStr = "\n" + errStyle.Render(m.err.Error()) + "\n"
}
return fmt.Sprintf(` return fmt.Sprintf(`
%s %s
Welcome back, %s. Welcome back, %s.
This environment is encrypted at rest. This environment is encrypted at rest.
Please enter your passphrase to unlock your keys. Please enter your passphrase to unlock your keys.
%s %s
`, sysStyle.Render("SECURE TUI CHAT"), senderStyle.Render(m.username), m.input.View()) %s
`, sysStyle.Render("SECURE TUI CHAT"), senderStyle.Render(m.username), errStr, m.input.View())
} }
if m.state == 1 { if m.state == 1 {
var b strings.Builder var b strings.Builder
b.WriteString(sysStyle.Render("=== YOUR ROOMS ===") + "\n\n") b.WriteString(sysStyle.Render("=== YOUR ROOMS ===") + "\n\n")
if len(m.rooms) == 0 { if len(m.rooms) == 0 {
b.WriteString(sysStyle.Render("No rooms yet.") + "\n") b.WriteString(sysStyle.Render("No rooms yet.") + "\n")
b.WriteString(commandStyle.Render(" /join <name>") + " - Join/Create room\n") b.WriteString(commandStyle.Render(" /join <name>") + " - Join/Create room\n")
@@ -601,7 +639,11 @@ Please enter your passphrase to unlock your keys.
prefix, prefix,
style.Render(room.Name))) style.Render(room.Name)))
} }
b.WriteString("\n" + sysStyle.Render("Commands: ") + commandStyle.Render("/join /dm /list") + "\n") b.WriteString("\n" + sysStyle.Render("Commands: ") + commandStyle.Render("/join /dm") + "\n")
}
if m.err != nil {
b.WriteString("\n" + errStyle.Render(m.err.Error()))
} }
b.WriteString("\n" + m.input.View()) b.WriteString("\n" + m.input.View())
return b.String() return b.String()