pollo/lib/session.go
2024-11-21 00:48:45 -06:00

176 lines
4.1 KiB
Go

package lib
import (
	"database/sql"
	"encoding/hex"
	"encoding/json"
	"errors"
	"log"
	"net/http"
	"os"
	"time"

	"github.com/gorilla/sessions"
	"github.com/labstack/echo-contrib/session"
	"github.com/labstack/echo/v4"
)
// SessionData is the per-session payload. It is JSON-encoded into the session
// cookie by SetSessionCookie, decoded again by GetSessionCookie, and is also
// the value ValidateSession returns for a live session.
type SessionData struct {
	// SessionID is the ID of the session row; SetSessionCookie overwrites it
	// with the ID returned by CreateSession.
	SessionID string `json:"sessionID"`
	// UserID is the ID of the user the session belongs to.
	UserID string `json:"userId"`
	// Name is the user's name.
	Name string `json:"name"`
	// Email is the user's email address.
	Email string `json:"email"`
	// Roles lists the roles attached to the session.
	Roles []string `json:"roles"`
}
// InitSessionMiddleware returns the echo session middleware backed by a
// gorilla cookie store whose key is the AUTH_SECRET environment variable.
//
// If AUTH_SECRET is unset or empty the process is terminated via log.Fatal.
func InitSessionMiddleware() echo.MiddlewareFunc {
	secret := os.Getenv("AUTH_SECRET")
	if len(secret) == 0 {
		log.Fatal("AUTH_SECRET environment variable is not set.")
	}

	return session.Middleware(sessions.NewCookieStore([]byte(secret)))
}
// SetSessionCookie creates a session row for sessionData.UserID and writes a
// cookie called name to w whose value is the JSON-encoded sessionData (with
// the new session ID filled in) passed through Encrypt.
//
// The session row is created with an expiry 48 hours from now, while the
// cookie is sent with a Max-Age of 3600 seconds.
// NOTE(review): the two lifetimes differ; confirm that this is intentional.
//
// An error from CreateSession, json.Marshal or Encrypt is returned as-is and
// no cookie is written in that case.
func SetSessionCookie(w http.ResponseWriter, name string, sessionData SessionData) error {
	// The session is stored first because its ID has to travel inside the cookie.
	expiry := time.Now().Add(48 * time.Hour)
	id, err := CreateSession(GetDB(), sessionData.UserID, expiry)
	if err != nil {
		return err
	}
	// sessionData was passed by value, so the caller's copy is not touched.
	sessionData.SessionID = id

	payload, err := json.Marshal(sessionData)
	if err != nil {
		return err
	}

	sealed, err := Encrypt(payload)
	if err != nil {
		return err
	}

	cookie := http.Cookie{
		Name:     name,
		Value:    sealed,
		Path:     "/",
		MaxAge:   3600, // seconds
		HttpOnly: true,
		SameSite: http.SameSiteStrictMode,
	}
	// The Secure flag is dropped only when DEVMODE is exactly "true".
	cookie.Secure = os.Getenv("DEVMODE") != "true"

	http.SetCookie(w, &cookie)
	return nil
}
// GetSessionCookie reads the cookie called name from r, passes its value
// through Decrypt and JSON-decodes the result into a SessionData.
//
// The error from r.Cookie, Decrypt or json.Unmarshal is returned unchanged,
// together with a nil *SessionData.
func GetSessionCookie(r *http.Request, name string) (*SessionData, error) {
	cookie, err := r.Cookie(name)
	if err != nil {
		return nil, err
	}

	plain, err := Decrypt(cookie.Value)
	if err != nil {
		return nil, err
	}

	data := new(SessionData)
	if err := json.Unmarshal([]byte(plain), data); err != nil {
		return nil, err
	}
	return data, nil
}
// ClearSessionCookie asks the client to drop the cookie called name by
// sending an empty-valued cookie with a negative MaxAge on w. The cookie is
// marked Secure unless the DEVMODE environment variable is exactly "true".
func ClearSessionCookie(w http.ResponseWriter, name string) {
	secure := os.Getenv("DEVMODE") != "true"

	// Value is left at its zero value (""); MaxAge < 0 means "delete now".
	expired := &http.Cookie{
		Name:     name,
		Path:     "/",
		MaxAge:   -1,
		HttpOnly: true,
		Secure:   secure,
		SameSite: http.SameSiteStrictMode,
	}
	http.SetCookie(w, expired)
}
// IsSignedIn reports whether the request carries a "session" cookie that
// decodes successfully and whose session ID is accepted by ValidateSession.
//
// A request without the cookie is the normal signed-out case: it returns
// false without logging. A cookie that is present but cannot be decoded, or
// whose session is rejected, is logged and cleared from the client so the
// browser stops re-sending a value that can never succeed.
func IsSignedIn(c echo.Context) bool {
	sessionData, err := GetSessionCookie(c.Request(), "session")
	if errors.Is(err, http.ErrNoCookie) {
		// Anonymous visitor; nothing to report and nothing to clear.
		return false
	}
	if err != nil {
		log.Printf("Error retrieving session cookie: %v", err)
		ClearSessionCookie(c.Response().Writer, "session")
		return false
	}

	// Validate the session in the database.
	validSessionData, err := ValidateSession(GetDB(), sessionData.SessionID)
	if err != nil || validSessionData == nil {
		log.Printf("Invalid session: %v", err)
		ClearSessionCookie(c.Response().Writer, "session")
		return false
	}
	return true
}
// GenerateSessionID returns the hexadecimal encoding of the bytes produced by
// GenerateRandomBytes(16). If GenerateRandomBytes fails, its error is returned
// with an empty string.
func GenerateSessionID() (string, error) {
	const idBytes = 16

	raw, err := GenerateRandomBytes(idBytes)
	if err != nil {
		return "", err
	}

	id := hex.EncodeToString(raw)
	return id, nil
}
// CreateSession inserts a row into the sessions table for userID with the
// given expiry and returns the new session's ID, which is obtained from
// GenerateNewID("session"). If the insert fails, the database error is
// returned with an empty ID.
func CreateSession(db *sql.DB, userID string, expiresAt time.Time) (string, error) {
	const insertSession = "INSERT INTO sessions (id, user_id, expires_at) VALUES (?, ?, ?)"

	id := GenerateNewID("session")
	if _, err := db.Exec(insertSession, id, userID, expiresAt); err != nil {
		return "", err
	}
	return id, nil
}
// ValidateSession looks up sessionID in the sessions table and, when the
// session exists and has not expired, loads its user via GetUserByID and
// returns the corresponding SessionData.
//
// Expiry is decided in Go from the scanned expires_at value instead of in
// SQL. The previous filter, expires_at > datetime('now'), compares the stored
// column against SQLite's UTC timestamp text, so the result is skewed whenever
// expires_at was written with a non-UTC offset.
// NOTE(review): this assumes the driver stores time.Time values together with
// their zone offset (SetSessionCookie passes a time.Now()-based value) —
// confirm against the driver in use.
//
// sql.ErrNoRows is returned both for an unknown session ID and for an expired
// session, so callers see the same error as before for either case. Errors
// from the query or from GetUserByID are returned unchanged.
func ValidateSession(db *sql.DB, sessionID string) (*SessionData, error) {
	var (
		userID    string
		expiresAt time.Time
	)
	err := db.QueryRow(
		"SELECT user_id, expires_at FROM sessions WHERE id = ?",
		sessionID,
	).Scan(&userID, &expiresAt)
	if err != nil {
		return nil, err
	}

	// time.Time comparisons are independent of the zone each value carries.
	if !expiresAt.After(time.Now()) {
		return nil, sql.ErrNoRows
	}

	user, err := GetUserByID(db, userID)
	if err != nil {
		return nil, err
	}

	return &SessionData{
		SessionID: sessionID,
		UserID:    user.ID,
		Name:      user.Name,
		Email:     user.Email,
		// Roles are not read from the database here; every session gets "user".
		Roles: []string{"user"},
	}, nil
}
// DeleteSession removes the row whose id equals sessionID from the sessions
// table. It returns the error from the DELETE statement, or nil on success.
func DeleteSession(db *sql.DB, sessionID string) error {
	const deleteSession = "DELETE FROM sessions WHERE id = ?"

	if _, err := db.Exec(deleteSession, sessionID); err != nil {
		return err
	}
	return nil
}