pollo/lib/session.go

202 lines
4.7 KiB
Go
Raw Normal View History

2024-06-27 21:23:51 -06:00
package lib
import (
2024-06-28 00:35:58 -06:00
"crypto/aes"
"crypto/cipher"
2024-06-27 21:23:51 -06:00
"crypto/rand"
2024-09-01 18:52:23 -06:00
"crypto/sha256"
2024-06-28 00:35:58 -06:00
"encoding/base64"
2024-06-27 21:23:51 -06:00
"encoding/hex"
2024-06-28 00:35:58 -06:00
"encoding/json"
"errors"
"fmt"
2024-06-27 21:23:51 -06:00
"log"
"net/http"
"os"
"github.com/gorilla/sessions"
"github.com/labstack/echo-contrib/session"
"github.com/labstack/echo/v4"
)
2024-06-28 00:35:58 -06:00
// SessionData is the payload that SetSessionCookie JSON-encodes and encrypts
// into the session cookie, and that GetSessionCookie decrypts and decodes.
// IsSignedIn uses UserID to confirm the session still refers to a real user.
//
// The JSON tag names are effectively the cookie's wire format: renaming one
// (for example "userId" to "userID") would make cookies issued before the
// change decode with that field left empty.
type SessionData struct {
	SessionID string   `json:"sessionID"`
	UserID    string   `json:"userId"`
	Name      string   `json:"name"`
	Email     string   `json:"email"`
	Roles     []string `json:"roles"`
}
2024-06-27 21:23:51 -06:00
// InitSessionMiddleware returns Echo session middleware backed by a gorilla
// cookie store whose secret is the AUTH_SECRET environment variable.
//
// If AUTH_SECRET is unset or empty the process exits via log.Fatal, so this is
// only suitable for startup-time wiring.
func InitSessionMiddleware() echo.MiddlewareFunc {
	secret := []byte(os.Getenv("AUTH_SECRET"))
	if len(secret) == 0 {
		log.Fatal("AUTH_SECRET environment variable is not set.")
	}
	return session.Middleware(sessions.NewCookieStore(secret))
}
2024-09-01 18:52:23 -06:00
// getEncryptionKey derives the 32-byte AES-256 key used for session cookies by
// hashing the ENCRYPTION_KEY environment variable with SHA-256.
//
// When ENCRYPTION_KEY is unset or empty it returns nil instead of hashing the
// empty string. SHA-256("") is a publicly known constant, so deriving a key
// from it would let anyone forge valid session cookies. A nil key makes
// aes.NewCipher fail, so encrypt and decrypt return errors instead of quietly
// using a guessable key.
func getEncryptionKey() []byte {
	secret := os.Getenv("ENCRYPTION_KEY")
	if secret == "" {
		return nil
	}
	sum := sha256.Sum256([]byte(secret))
	// Sum256 yields a [32]byte array, which is the full digest; slice it to
	// return a []byte.
	return sum[:]
}
2024-06-28 00:35:58 -06:00
// Encrypt data using AES
func encrypt(data []byte) (string, error) {
2024-09-01 18:52:23 -06:00
encryptionKey := getEncryptionKey()
2024-06-28 00:35:58 -06:00
fmt.Printf("Encryption Key Length: %d\n", len(encryptionKey))
block, err := aes.NewCipher(encryptionKey)
if err != nil {
return "", err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", err
}
nonce := make([]byte, gcm.NonceSize())
_, err = rand.Read(nonce)
if err != nil {
return "", err
}
cipherText := gcm.Seal(nonce, nonce, data, nil)
return base64.StdEncoding.EncodeToString(cipherText), nil
}
// decrypt decrypts the data using AES-GCM.
func decrypt(encryptedString string) (string, error) {
2024-09-01 18:52:23 -06:00
encryptionKey := getEncryptionKey()
2024-06-28 00:35:58 -06:00
data, err := base64.StdEncoding.DecodeString(encryptedString)
if err != nil {
return "", err
}
block, err := aes.NewCipher(encryptionKey)
if err != nil {
return "", err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", err
}
nonceSize := gcm.NonceSize()
if len(data) < nonceSize {
return "", errors.New("ciphertext too short")
}
nonce, ciphertext := data[:nonceSize], data[nonceSize:]
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return "", err
}
return string(plaintext), nil
}
// Adjusted SetSessionCookie to include more user info
func SetSessionCookie(w http.ResponseWriter, name string, sessionData SessionData) {
// Serialize session data
dataBytes, err := json.Marshal(sessionData)
if err != nil {
log.Fatal("Failed to serialize session data:", err)
}
2024-06-27 21:23:51 -06:00
2024-06-28 00:35:58 -06:00
// Encrypt serialized session data
encryptedData, err := encrypt(dataBytes)
if err != nil {
log.Fatal("Failed to encrypt session data:", err)
}
// Set cookie with encrypted data
2024-06-27 21:23:51 -06:00
http.SetCookie(w, &http.Cookie{
Name: name,
2024-06-28 00:35:58 -06:00
Value: encryptedData,
2024-06-27 21:23:51 -06:00
Path: "/",
HttpOnly: true,
2024-06-28 00:35:58 -06:00
Secure: os.Getenv("DEVMODE") != "true",
2024-06-27 21:23:51 -06:00
SameSite: http.SameSiteStrictMode,
MaxAge: 3600,
})
}
2024-06-28 00:35:58 -06:00
// Adjusted GetSessionCookie to return SessionData
func GetSessionCookie(r *http.Request, name string) (*SessionData, error) {
2024-06-27 21:23:51 -06:00
cookie, err := r.Cookie(name)
if err != nil {
2024-06-28 00:35:58 -06:00
return nil, err
2024-06-27 21:23:51 -06:00
}
2024-06-28 00:35:58 -06:00
// Decrypt the cookie value
decryptedValue, err := decrypt(cookie.Value)
if err != nil {
return nil, err
}
// Deserialize session data
var sessionData SessionData
err = json.Unmarshal([]byte(decryptedValue), &sessionData)
if err != nil {
return nil, err
}
return &sessionData, nil
2024-06-27 21:23:51 -06:00
}
2024-06-27 21:51:00 -06:00
// ClearSessionCookie tells the client to delete the cookie called name. It
// re-sends the cookie with an empty value and a negative MaxAge, which net/http
// emits as "Max-Age=0". Path "/", HttpOnly, SameSite=Strict, and the
// DEVMODE-controlled Secure flag match the attributes SetSessionCookie uses.
func ClearSessionCookie(w http.ResponseWriter, name string) {
	secure := os.Getenv("DEVMODE") != "true"
	expired := http.Cookie{
		Name:     name,
		Path:     "/",
		MaxAge:   -1,
		HttpOnly: true,
		Secure:   secure,
		SameSite: http.SameSiteStrictMode,
	}
	http.SetCookie(w, &expired)
}
2024-06-27 21:23:51 -06:00
// Checks if the user is signed in by checking the session cookie
func IsSignedIn(c echo.Context) bool {
2024-06-28 09:11:03 -06:00
sessionData, err := GetSessionCookie(c.Request(), "session")
2024-06-28 00:35:58 -06:00
if err != nil {
log.Printf("Error retrieving session cookie: %v", err)
2024-06-28 09:11:03 -06:00
return false
2024-06-28 00:35:58 -06:00
}
2024-06-28 09:11:03 -06:00
// Validate the session data by checking if the user exists in the database
user, err := GetUserByID(GetDBPool(), sessionData.UserID)
if err != nil || user == nil {
log.Printf("Session refers to a non-existent user: %v", err)
ClearSessionCookie(c.Response().Writer, "session")
return false
}
return true
2024-06-27 21:23:51 -06:00
}
// GenerateRandomBytes returns a new slice of n bytes filled from crypto/rand.
// If reading randomness fails it returns a nil slice and the error from
// rand.Read.
func GenerateRandomBytes(n int) ([]byte, error) {
	buf := make([]byte, n)
	if _, err := rand.Read(buf); err != nil {
		return nil, err
	}
	return buf, nil
}
// GenerateSessionID returns a new random session identifier: 16 bytes from
// GenerateRandomBytes, hex-encoded into 32 lowercase characters. Any error
// from the random source is returned with an empty ID.
func GenerateSessionID() (string, error) {
	const idBytes = 16
	raw, err := GenerateRandomBytes(idBytes)
	if err != nil {
		return "", err
	}
	return hex.EncodeToString(raw), nil
}