diff --git a/api/signin.go b/api/signin.go index e2d4be7..13458c5 100644 --- a/api/signin.go +++ b/api/signin.go @@ -25,13 +25,16 @@ func SignInUserHandler(c echo.Context) error { } // Set the session cookie with the generated session ID - lib.SetSessionCookie(c.Response().Writer, "session", lib.SessionData{ + err = lib.SetSessionCookie(c.Response().Writer, "session", lib.SessionData{ SessionID: sessionID, - Name: user.Name, UserID: user.ID, + Name: user.Name, Email: user.Email, Roles: []string{"user"}, }) + if err != nil { + return c.JSON(http.StatusInternalServerError, "Failed to create session") + } // Proceed with login success logic c.Response().Header().Set("HX-Redirect", "/") diff --git a/api/signout.go b/api/signout.go index e614008..c556aa2 100644 --- a/api/signout.go +++ b/api/signout.go @@ -1,6 +1,7 @@ package api import ( + "log" "net/http" "pollo/lib" @@ -8,10 +9,19 @@ import ( ) func SignOutUserHandler(c echo.Context) error { + sessionData, err := lib.GetSessionCookie(c.Request(), "session") + if err == nil && sessionData != nil { + // Delete the session from the database + err = lib.DeleteSession(lib.GetDBPool(), sessionData.SessionID) + if err != nil { + log.Printf("Error deleting session: %v", err) + } + } + // Clear the session cookie lib.ClearSessionCookie(c.Response().Writer, "session") - // Proceed with login success logic + // Proceed with logout success logic c.Response().Header().Set("HX-Redirect", "/") return c.NoContent(http.StatusOK) } diff --git a/lib/crypto.go b/lib/crypto.go new file mode 100644 index 0000000..7b84b46 --- /dev/null +++ b/lib/crypto.go @@ -0,0 +1,85 @@ +package lib + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "errors" + "fmt" + "os" +) + +// Returns the first 32 bytes of the SHA-256 hash of the ENCRYPTION_KEY environment variable +func GetEncryptionKey() []byte { + key := []byte(os.Getenv("ENCRYPTION_KEY")) + hash := sha256.Sum256(key) + return hash[:32] // Use the first 32 bytes for AES-256 +} + +// Encrypt data using AES +func Encrypt(data []byte) (string, error) { + encryptionKey := GetEncryptionKey() + fmt.Printf("Encryption Key Length: %d\n", len(encryptionKey)) + + block, err := aes.NewCipher(encryptionKey) + if err != nil { + return "", err + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", err + } + nonce := make([]byte, gcm.NonceSize()) + _, err = rand.Read(nonce) + if err != nil { + return "", err + } + cipherText := gcm.Seal(nonce, nonce, data, nil) + return base64.StdEncoding.EncodeToString(cipherText), nil +} + +// decrypt decrypts the data using AES-GCM. +func Decrypt(encryptedString string) (string, error) { + encryptionKey := GetEncryptionKey() + + data, err := base64.StdEncoding.DecodeString(encryptedString) + if err != nil { + return "", err + } + + block, err := aes.NewCipher(encryptionKey) + if err != nil { + return "", err + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return "", err + } + + nonceSize := gcm.NonceSize() + if len(data) < nonceSize { + return "", errors.New("ciphertext too short") + } + + nonce, ciphertext := data[:nonceSize], data[nonceSize:] + plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + return "", err + } + + return string(plaintext), nil +} + +// GenerateRandomBytes returns securely generated random bytes. +func GenerateRandomBytes(n int) ([]byte, error) { + b := make([]byte, n) + _, err := rand.Read(b) + if err != nil { + return nil, err + } + + return b, nil +} diff --git a/lib/db.go b/lib/db.go index a84ec92..5296754 100644 --- a/lib/db.go +++ b/lib/db.go @@ -81,6 +81,19 @@ func InitializeSchema(dbPool *pgxpool.Pool) error { CREATE INDEX IF NOT EXISTS idx_rooms_userid ON rooms(userId); `, }, + { + Version: "6_create_sessions_table", + SQL: ` + CREATE TABLE IF NOT EXISTS sessions ( + id TEXT PRIMARY KEY, + user_id TEXT NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + expires_at TIMESTAMP NOT NULL, + FOREIGN KEY (user_id) REFERENCES users(id) + ); + CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions(user_id); + `, + }, } // Ensure the migrations table exists diff --git a/lib/session.go b/lib/session.go index 76a1525..3699bce 100644 --- a/lib/session.go +++ b/lib/session.go @@ -1,20 +1,16 @@ package lib import ( - "crypto/aes" - "crypto/cipher" - "crypto/rand" - "crypto/sha256" - "encoding/base64" + "context" "encoding/hex" "encoding/json" - "errors" - "fmt" "log" "net/http" "os" + "time" "github.com/gorilla/sessions" + "github.com/jackc/pgx/v4/pgxpool" "github.com/labstack/echo-contrib/session" "github.com/labstack/echo/v4" ) @@ -37,80 +33,27 @@ func InitSessionMiddleware() echo.MiddlewareFunc { return session.Middleware(store) } -// Returns the first 32 bytes of the SHA-256 hash of the ENCRYPTION_KEY environment variable -func getEncryptionKey() []byte { - key := []byte(os.Getenv("ENCRYPTION_KEY")) - hash := sha256.Sum256(key) - return hash[:32] // Use the first 32 bytes for AES-256 -} - -// Encrypt data using AES -func encrypt(data []byte) (string, error) { - encryptionKey := getEncryptionKey() - fmt.Printf("Encryption Key Length: %d\n", len(encryptionKey)) - - block, err := aes.NewCipher(encryptionKey) - if err != nil { - return "", err - } - gcm, err := cipher.NewGCM(block) - if err != nil { - return "", err - } - nonce := make([]byte, gcm.NonceSize()) - _, err = rand.Read(nonce) - if err != nil { - return "", err - } - cipherText := gcm.Seal(nonce, nonce, data, nil) - return base64.StdEncoding.EncodeToString(cipherText), nil -} - -// decrypt decrypts the data using AES-GCM. -func decrypt(encryptedString string) (string, error) { - encryptionKey := getEncryptionKey() - - data, err := base64.StdEncoding.DecodeString(encryptedString) - if err != nil { - return "", err - } - - block, err := aes.NewCipher(encryptionKey) - if err != nil { - return "", err - } - - gcm, err := cipher.NewGCM(block) - if err != nil { - return "", err - } - - nonceSize := gcm.NonceSize() - if len(data) < nonceSize { - return "", errors.New("ciphertext too short") - } - - nonce, ciphertext := data[:nonceSize], data[nonceSize:] - plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) - if err != nil { - return "", err - } - - return string(plaintext), nil -} - // Adjusted SetSessionCookie to include more user info -func SetSessionCookie(w http.ResponseWriter, name string, sessionData SessionData) { +func SetSessionCookie(w http.ResponseWriter, name string, sessionData SessionData) error { + // Create session in database + expiresAt := time.Now().Add(48 * time.Hour) // Set expiration to 1 hour from now + sessionID, err := CreateSession(GetDBPool(), sessionData.UserID, expiresAt) + if err != nil { + return err + } + + sessionData.SessionID = sessionID + // Serialize session data dataBytes, err := json.Marshal(sessionData) if err != nil { - log.Fatal("Failed to serialize session data:", err) + return err } // Encrypt serialized session data - encryptedData, err := encrypt(dataBytes) + encryptedData, err := Encrypt(dataBytes) if err != nil { - log.Fatal("Failed to encrypt session data:", err) + return err } // Set cookie with encrypted data @@ -121,8 +64,10 @@ func SetSessionCookie(w http.ResponseWriter, name string, sessionData SessionDat HttpOnly: true, Secure: os.Getenv("DEVMODE") != "true", SameSite: http.SameSiteStrictMode, - MaxAge: 3600, + MaxAge: 3600, // 1 hour }) + + return nil } // Adjusted GetSessionCookie to return SessionData @@ -133,7 +78,7 @@ func GetSessionCookie(r *http.Request, name string) (*SessionData, error) { } // Decrypt the cookie value - decryptedValue, err := decrypt(cookie.Value) + decryptedValue, err := Decrypt(cookie.Value) if err != nil { return nil, err } @@ -169,10 +114,10 @@ func IsSignedIn(c echo.Context) bool { return false } - // Validate the session data by checking if the user exists in the database - user, err := GetUserByID(GetDBPool(), sessionData.UserID) - if err != nil || user == nil { - log.Printf("Session refers to a non-existent user: %v", err) + // Validate the session in the database + validSessionData, err := ValidateSession(GetDBPool(), sessionData.SessionID) + if err != nil || validSessionData == nil { + log.Printf("Invalid session: %v", err) ClearSessionCookie(c.Response().Writer, "session") return false } @@ -180,17 +125,6 @@ func IsSignedIn(c echo.Context) bool { return true } -// GenerateRandomBytes returns securely generated random bytes. -func GenerateRandomBytes(n int) ([]byte, error) { - b := make([]byte, n) - _, err := rand.Read(b) - if err != nil { - return nil, err - } - - return b, nil -} - // GenerateSessionID generates a new session ID. func GenerateSessionID() (string, error) { bytes, err := GenerateRandomBytes(16) @@ -199,3 +133,44 @@ func GenerateSessionID() (string, error) { } return hex.EncodeToString(bytes), nil } + +func CreateSession(dbPool *pgxpool.Pool, userID string, expiresAt time.Time) (string, error) { + sessionID := GenerateNewID("session") + _, err := dbPool.Exec(context.Background(), + "INSERT INTO sessions (id, user_id, expires_at) VALUES ($1, $2, $3)", + sessionID, userID, expiresAt) + if err != nil { + return "", err + } + return sessionID, nil +} + +func ValidateSession(dbPool *pgxpool.Pool, sessionID string) (*SessionData, error) { + var userID string + var expiresAt time.Time + err := dbPool.QueryRow(context.Background(), + "SELECT user_id, expires_at FROM sessions WHERE id = $1 AND expires_at > NOW()", + sessionID).Scan(&userID, &expiresAt) + if err != nil { + return nil, err + } + + user, err := GetUserByID(dbPool, userID) + if err != nil { + return nil, err + } + + return &SessionData{ + SessionID: sessionID, + UserID: user.ID, + Name: user.Name, + Email: user.Email, + Roles: []string{"user"}, + }, nil +} + +func DeleteSession(dbPool *pgxpool.Pool, sessionID string) error { + _, err := dbPool.Exec(context.Background(), + "DELETE FROM sessions WHERE id = $1", sessionID) + return err +}