Session table work now

This commit is contained in:
Atridad Lahiji 2024-09-16 13:46:25 -06:00
parent 32374cfbb4
commit 632568ecf7
Signed by: atridad
SSH key fingerprint: SHA256:LGomp8Opq0jz+7kbwNcdfTcuaLRb5Nh0k5AchDDb438
5 changed files with 179 additions and 93 deletions

View file

@ -25,13 +25,16 @@ func SignInUserHandler(c echo.Context) error {
}
// Set the session cookie with the generated session ID
lib.SetSessionCookie(c.Response().Writer, "session", lib.SessionData{
err = lib.SetSessionCookie(c.Response().Writer, "session", lib.SessionData{
SessionID: sessionID,
Name: user.Name,
UserID: user.ID,
Name: user.Name,
Email: user.Email,
Roles: []string{"user"},
})
if err != nil {
return c.JSON(http.StatusInternalServerError, "Failed to create session")
}
// Proceed with login success logic
c.Response().Header().Set("HX-Redirect", "/")

View file

@ -1,6 +1,7 @@
package api
import (
"log"
"net/http"
"pollo/lib"
@ -8,10 +9,19 @@ import (
)
func SignOutUserHandler(c echo.Context) error {
sessionData, err := lib.GetSessionCookie(c.Request(), "session")
if err == nil && sessionData != nil {
// Delete the session from the database
err = lib.DeleteSession(lib.GetDBPool(), sessionData.SessionID)
if err != nil {
log.Printf("Error deleting session: %v", err)
}
}
// Clear the session cookie
lib.ClearSessionCookie(c.Response().Writer, "session")
// Proceed with login success logic
// Proceed with logout success logic
c.Response().Header().Set("HX-Redirect", "/")
return c.NoContent(http.StatusOK)
}

85
lib/crypto.go Normal file
View file

@ -0,0 +1,85 @@
package lib
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"os"
)
// Returns the first 32 bytes of the SHA-256 hash of the ENCRYPTION_KEY environment variable
func GetEncryptionKey() []byte {
key := []byte(os.Getenv("ENCRYPTION_KEY"))
hash := sha256.Sum256(key)
return hash[:32] // Use the first 32 bytes for AES-256
}
// Encrypt data using AES
func Encrypt(data []byte) (string, error) {
encryptionKey := GetEncryptionKey()
fmt.Printf("Encryption Key Length: %d\n", len(encryptionKey))
block, err := aes.NewCipher(encryptionKey)
if err != nil {
return "", err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", err
}
nonce := make([]byte, gcm.NonceSize())
_, err = rand.Read(nonce)
if err != nil {
return "", err
}
cipherText := gcm.Seal(nonce, nonce, data, nil)
return base64.StdEncoding.EncodeToString(cipherText), nil
}
// decrypt decrypts the data using AES-GCM.
func Decrypt(encryptedString string) (string, error) {
encryptionKey := GetEncryptionKey()
data, err := base64.StdEncoding.DecodeString(encryptedString)
if err != nil {
return "", err
}
block, err := aes.NewCipher(encryptionKey)
if err != nil {
return "", err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", err
}
nonceSize := gcm.NonceSize()
if len(data) < nonceSize {
return "", errors.New("ciphertext too short")
}
nonce, ciphertext := data[:nonceSize], data[nonceSize:]
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return "", err
}
return string(plaintext), nil
}
// GenerateRandomBytes returns securely generated random bytes.
func GenerateRandomBytes(n int) ([]byte, error) {
b := make([]byte, n)
_, err := rand.Read(b)
if err != nil {
return nil, err
}
return b, nil
}

View file

@ -81,6 +81,19 @@ func InitializeSchema(dbPool *pgxpool.Pool) error {
CREATE INDEX IF NOT EXISTS idx_rooms_userid ON rooms(userId);
`,
},
{
Version: "6_create_sessions_table",
SQL: `
CREATE TABLE IF NOT EXISTS sessions (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
expires_at TIMESTAMP NOT NULL,
FOREIGN KEY (user_id) REFERENCES users(id)
);
CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions(user_id);
`,
},
}
// Ensure the migrations table exists

View file

@ -1,20 +1,16 @@
package lib
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"context"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"os"
"time"
"github.com/gorilla/sessions"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/labstack/echo-contrib/session"
"github.com/labstack/echo/v4"
)
@ -37,80 +33,27 @@ func InitSessionMiddleware() echo.MiddlewareFunc {
return session.Middleware(store)
}
// Returns the first 32 bytes of the SHA-256 hash of the ENCRYPTION_KEY environment variable
func getEncryptionKey() []byte {
key := []byte(os.Getenv("ENCRYPTION_KEY"))
hash := sha256.Sum256(key)
return hash[:32] // Use the first 32 bytes for AES-256
}
// Encrypt data using AES
func encrypt(data []byte) (string, error) {
encryptionKey := getEncryptionKey()
fmt.Printf("Encryption Key Length: %d\n", len(encryptionKey))
block, err := aes.NewCipher(encryptionKey)
if err != nil {
return "", err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", err
}
nonce := make([]byte, gcm.NonceSize())
_, err = rand.Read(nonce)
if err != nil {
return "", err
}
cipherText := gcm.Seal(nonce, nonce, data, nil)
return base64.StdEncoding.EncodeToString(cipherText), nil
}
// decrypt decrypts the data using AES-GCM.
func decrypt(encryptedString string) (string, error) {
encryptionKey := getEncryptionKey()
data, err := base64.StdEncoding.DecodeString(encryptedString)
if err != nil {
return "", err
}
block, err := aes.NewCipher(encryptionKey)
if err != nil {
return "", err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", err
}
nonceSize := gcm.NonceSize()
if len(data) < nonceSize {
return "", errors.New("ciphertext too short")
}
nonce, ciphertext := data[:nonceSize], data[nonceSize:]
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return "", err
}
return string(plaintext), nil
}
// Adjusted SetSessionCookie to include more user info
func SetSessionCookie(w http.ResponseWriter, name string, sessionData SessionData) {
func SetSessionCookie(w http.ResponseWriter, name string, sessionData SessionData) error {
// Create session in database
expiresAt := time.Now().Add(48 * time.Hour) // Set expiration to 1 hour from now
sessionID, err := CreateSession(GetDBPool(), sessionData.UserID, expiresAt)
if err != nil {
return err
}
sessionData.SessionID = sessionID
// Serialize session data
dataBytes, err := json.Marshal(sessionData)
if err != nil {
log.Fatal("Failed to serialize session data:", err)
return err
}
// Encrypt serialized session data
encryptedData, err := encrypt(dataBytes)
encryptedData, err := Encrypt(dataBytes)
if err != nil {
log.Fatal("Failed to encrypt session data:", err)
return err
}
// Set cookie with encrypted data
@ -121,8 +64,10 @@ func SetSessionCookie(w http.ResponseWriter, name string, sessionData SessionDat
HttpOnly: true,
Secure: os.Getenv("DEVMODE") != "true",
SameSite: http.SameSiteStrictMode,
MaxAge: 3600,
MaxAge: 3600, // 1 hour
})
return nil
}
// Adjusted GetSessionCookie to return SessionData
@ -133,7 +78,7 @@ func GetSessionCookie(r *http.Request, name string) (*SessionData, error) {
}
// Decrypt the cookie value
decryptedValue, err := decrypt(cookie.Value)
decryptedValue, err := Decrypt(cookie.Value)
if err != nil {
return nil, err
}
@ -169,10 +114,10 @@ func IsSignedIn(c echo.Context) bool {
return false
}
// Validate the session data by checking if the user exists in the database
user, err := GetUserByID(GetDBPool(), sessionData.UserID)
if err != nil || user == nil {
log.Printf("Session refers to a non-existent user: %v", err)
// Validate the session in the database
validSessionData, err := ValidateSession(GetDBPool(), sessionData.SessionID)
if err != nil || validSessionData == nil {
log.Printf("Invalid session: %v", err)
ClearSessionCookie(c.Response().Writer, "session")
return false
}
@ -180,17 +125,6 @@ func IsSignedIn(c echo.Context) bool {
return true
}
// GenerateRandomBytes returns securely generated random bytes.
func GenerateRandomBytes(n int) ([]byte, error) {
b := make([]byte, n)
_, err := rand.Read(b)
if err != nil {
return nil, err
}
return b, nil
}
// GenerateSessionID generates a new session ID.
func GenerateSessionID() (string, error) {
bytes, err := GenerateRandomBytes(16)
@ -199,3 +133,44 @@ func GenerateSessionID() (string, error) {
}
return hex.EncodeToString(bytes), nil
}
func CreateSession(dbPool *pgxpool.Pool, userID string, expiresAt time.Time) (string, error) {
sessionID := GenerateNewID("session")
_, err := dbPool.Exec(context.Background(),
"INSERT INTO sessions (id, user_id, expires_at) VALUES ($1, $2, $3)",
sessionID, userID, expiresAt)
if err != nil {
return "", err
}
return sessionID, nil
}
func ValidateSession(dbPool *pgxpool.Pool, sessionID string) (*SessionData, error) {
var userID string
var expiresAt time.Time
err := dbPool.QueryRow(context.Background(),
"SELECT user_id, expires_at FROM sessions WHERE id = $1 AND expires_at > NOW()",
sessionID).Scan(&userID, &expiresAt)
if err != nil {
return nil, err
}
user, err := GetUserByID(dbPool, userID)
if err != nil {
return nil, err
}
return &SessionData{
SessionID: sessionID,
UserID: user.ID,
Name: user.Name,
Email: user.Email,
Roles: []string{"user"},
}, nil
}
func DeleteSession(dbPool *pgxpool.Pool, sessionID string) error {
_, err := dbPool.Exec(context.Background(),
"DELETE FROM sessions WHERE id = $1", sessionID)
return err
}