2024-06-27 21:23:51 -06:00
|
|
|
package lib
|
|
|
|
|
|
|
|
import (
|
2024-06-28 00:35:58 -06:00
|
|
|
"crypto/aes"
|
|
|
|
"crypto/cipher"
|
2024-06-27 21:23:51 -06:00
|
|
|
"crypto/rand"
|
2024-06-28 00:35:58 -06:00
|
|
|
"encoding/base64"
|
2024-06-27 21:23:51 -06:00
|
|
|
"encoding/hex"
|
2024-06-28 00:35:58 -06:00
|
|
|
"encoding/json"
|
|
|
|
"errors"
|
|
|
|
"fmt"
|
2024-06-27 21:23:51 -06:00
|
|
|
"log"
|
|
|
|
"net/http"
|
|
|
|
"os"
|
|
|
|
|
|
|
|
"github.com/gorilla/sessions"
|
|
|
|
"github.com/labstack/echo-contrib/session"
|
|
|
|
"github.com/labstack/echo/v4"
|
|
|
|
)
|
|
|
|
|
2024-06-28 00:35:58 -06:00
|
|
|
// SessionData is the payload stored in the session cookie: SetSessionCookie
// serializes it to JSON and encrypts it, and GetSessionCookie decrypts and
// decodes it back.
//
// The JSON tags below define the serialized cookie format. Cookies that were
// already issued carry these exact keys, so change the tags with care.
type SessionData struct {
	// SessionID identifies this session (presumably produced by
	// GenerateSessionID — confirm at the call site).
	SessionID string `json:"sessionID"`

	// UserID identifies the signed-in user.
	UserID string `json:"userId"`

	// Email is the signed-in user's email address.
	Email string `json:"email"`

	// Roles lists the user's roles; a nil slice is encoded as JSON null.
	Roles []string `json:"roles"`
}
|
|
|
|
|
2024-06-27 21:23:51 -06:00
|
|
|
func InitSessionMiddleware() echo.MiddlewareFunc {
|
|
|
|
authSecret := os.Getenv("AUTH_SECRET")
|
|
|
|
if authSecret == "" {
|
|
|
|
log.Fatal("AUTH_SECRET environment variable is not set.")
|
|
|
|
}
|
|
|
|
|
|
|
|
store := sessions.NewCookieStore([]byte(authSecret))
|
|
|
|
return session.Middleware(store)
|
|
|
|
}
|
|
|
|
|
2024-06-28 00:35:58 -06:00
|
|
|
// encrypt seals data with AES-GCM and returns the result as standard base64.
// The key is the raw bytes of the ENCRYPTION_KEY environment variable.
//
// The key must be 16, 24, or 32 bytes (AES-128/192/256). Any other length,
// including an unset variable, makes aes.NewCipher fail, and that error is
// returned wrapped with context.
//
// The encoded output is nonce || ciphertext || tag. decrypt relies on this
// layout: it splits the nonce off the front.
func encrypt(data []byte) (string, error) {
	key := []byte(os.Getenv("ENCRYPTION_KEY"))

	block, err := aes.NewCipher(key)
	if err != nil {
		return "", fmt.Errorf("creating AES cipher from ENCRYPTION_KEY: %w", err)
	}

	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return "", fmt.Errorf("creating AES-GCM: %w", err)
	}

	// GCM must never reuse a nonce under the same key, so each message gets a
	// fresh random one. It is stored as the ciphertext prefix so decrypt can
	// recover it.
	nonce := make([]byte, gcm.NonceSize())
	if _, err := rand.Read(nonce); err != nil {
		return "", fmt.Errorf("generating GCM nonce: %w", err)
	}

	// Seal appends ciphertext+tag to its first argument, so passing nonce as
	// dst produces nonce||ciphertext||tag in one slice.
	sealed := gcm.Seal(nonce, nonce, data, nil)
	return base64.StdEncoding.EncodeToString(sealed), nil
}
|
|
|
|
|
|
|
|
// decrypt reverses encrypt. It base64-decodes encryptedString, treats the
// first GCM-nonce-size bytes as the nonce, and authenticates and decrypts the
// rest using the raw bytes of ENCRYPTION_KEY as the AES key.
//
// It returns "ciphertext too short" when the decoded input is shorter than
// the nonce. Base64, cipher-setup, and GCM authentication errors (tampered
// input or a different key) are returned unchanged.
func decrypt(encryptedString string) (string, error) {
	key := []byte(os.Getenv("ENCRYPTION_KEY"))

	raw, err := base64.StdEncoding.DecodeString(encryptedString)
	if err != nil {
		return "", err
	}

	aesBlock, err := aes.NewCipher(key)
	if err != nil {
		return "", err
	}
	aead, err := cipher.NewGCM(aesBlock)
	if err != nil {
		return "", err
	}

	n := aead.NonceSize()
	if len(raw) < n {
		return "", errors.New("ciphertext too short")
	}

	// raw[:n] is the nonce that encrypt prepended; the remainder is
	// ciphertext followed by the authentication tag.
	opened, err := aead.Open(nil, raw[:n], raw[n:], nil)
	if err != nil {
		return "", err
	}
	return string(opened), nil
}
|
|
|
|
|
|
|
|
// Adjusted SetSessionCookie to include more user info
|
|
|
|
func SetSessionCookie(w http.ResponseWriter, name string, sessionData SessionData) {
|
|
|
|
// Serialize session data
|
|
|
|
dataBytes, err := json.Marshal(sessionData)
|
|
|
|
if err != nil {
|
|
|
|
log.Fatal("Failed to serialize session data:", err)
|
|
|
|
}
|
2024-06-27 21:23:51 -06:00
|
|
|
|
2024-06-28 00:35:58 -06:00
|
|
|
// Encrypt serialized session data
|
|
|
|
encryptedData, err := encrypt(dataBytes)
|
|
|
|
if err != nil {
|
|
|
|
log.Fatal("Failed to encrypt session data:", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
// Set cookie with encrypted data
|
2024-06-27 21:23:51 -06:00
|
|
|
http.SetCookie(w, &http.Cookie{
|
|
|
|
Name: name,
|
2024-06-28 00:35:58 -06:00
|
|
|
Value: encryptedData,
|
2024-06-27 21:23:51 -06:00
|
|
|
Path: "/",
|
|
|
|
HttpOnly: true,
|
2024-06-28 00:35:58 -06:00
|
|
|
Secure: os.Getenv("DEVMODE") != "true",
|
2024-06-27 21:23:51 -06:00
|
|
|
SameSite: http.SameSiteStrictMode,
|
|
|
|
MaxAge: 3600,
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
2024-06-28 00:35:58 -06:00
|
|
|
// Adjusted GetSessionCookie to return SessionData
|
|
|
|
func GetSessionCookie(r *http.Request, name string) (*SessionData, error) {
|
2024-06-27 21:23:51 -06:00
|
|
|
cookie, err := r.Cookie(name)
|
|
|
|
if err != nil {
|
2024-06-28 00:35:58 -06:00
|
|
|
return nil, err
|
2024-06-27 21:23:51 -06:00
|
|
|
}
|
2024-06-28 00:35:58 -06:00
|
|
|
|
|
|
|
// Decrypt the cookie value
|
|
|
|
decryptedValue, err := decrypt(cookie.Value)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
// Deserialize session data
|
|
|
|
var sessionData SessionData
|
|
|
|
err = json.Unmarshal([]byte(decryptedValue), &sessionData)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
return &sessionData, nil
|
2024-06-27 21:23:51 -06:00
|
|
|
}
|
|
|
|
|
2024-06-27 21:51:00 -06:00
|
|
|
// ClearSessionCookie tells the browser to delete the cookie named name. It
// re-sends the cookie with an empty value and a negative MaxAge, which
// net/http emits as "Max-Age=0".
//
// Path is "/" to match the path SetSessionCookie uses, so the browser
// targets the same cookie. HttpOnly, Secure (unless DEVMODE=true), and
// SameSite=Strict follow the same policy as SetSessionCookie.
func ClearSessionCookie(w http.ResponseWriter, name string) {
	secure := os.Getenv("DEVMODE") != "true"
	expired := &http.Cookie{
		Name:     name,
		Value:    "",
		Path:     "/",
		MaxAge:   -1, // tells the browser to drop the cookie immediately
		HttpOnly: true,
		Secure:   secure,
		SameSite: http.SameSiteStrictMode,
	}
	http.SetCookie(w, expired)
}
|
|
|
|
|
2024-06-27 21:23:51 -06:00
|
|
|
// Checks if the user is signed in by checking the session cookie
|
|
|
|
func IsSignedIn(c echo.Context) bool {
|
2024-06-28 00:35:58 -06:00
|
|
|
_, err := GetSessionCookie(c.Request(), "session")
|
|
|
|
if err != nil {
|
|
|
|
// Log the error for debugging purposes
|
|
|
|
log.Printf("Error retrieving session cookie: %v", err)
|
|
|
|
}
|
2024-06-27 21:23:51 -06:00
|
|
|
return err == nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// GenerateRandomBytes returns n bytes from crypto/rand, the operating
// system's cryptographically secure random source. On failure it returns a
// nil slice and the error from rand.Read.
func GenerateRandomBytes(n int) ([]byte, error) {
	buf := make([]byte, n)
	if _, err := rand.Read(buf); err != nil {
		return nil, err
	}
	return buf, nil
}
|
|
|
|
|
|
|
|
// GenerateSessionID generates a new session ID.
|
|
|
|
func GenerateSessionID() (string, error) {
|
|
|
|
bytes, err := GenerateRandomBytes(16)
|
|
|
|
if err != nil {
|
|
|
|
return "", err
|
|
|
|
}
|
|
|
|
return hex.EncodeToString(bytes), nil
|
|
|
|
}
|