pollo/lib/session.go

177 lines
4.1 KiB
Go
Raw Normal View History

2024-06-27 21:23:51 -06:00
package lib
import (
2024-11-21 00:48:45 -06:00
"database/sql"
2024-06-27 21:23:51 -06:00
"encoding/hex"
2024-06-28 00:35:58 -06:00
"encoding/json"
2024-06-27 21:23:51 -06:00
"log"
"net/http"
"os"
2024-09-16 13:46:25 -06:00
"time"
2024-06-27 21:23:51 -06:00
"github.com/gorilla/sessions"
"github.com/labstack/echo-contrib/session"
"github.com/labstack/echo/v4"
)
2024-06-28 00:35:58 -06:00
type SessionData struct {
SessionID string `json:"sessionID"`
UserID string `json:"userId"`
2024-06-28 09:11:03 -06:00
Name string `json:"name"`
2024-06-28 00:35:58 -06:00
Email string `json:"email"`
Roles []string `json:"roles"`
}
2024-06-27 21:23:51 -06:00
func InitSessionMiddleware() echo.MiddlewareFunc {
authSecret := os.Getenv("AUTH_SECRET")
if authSecret == "" {
log.Fatal("AUTH_SECRET environment variable is not set.")
}
store := sessions.NewCookieStore([]byte(authSecret))
return session.Middleware(store)
}
2024-09-16 13:46:25 -06:00
// Adjusted SetSessionCookie to include more user info
func SetSessionCookie(w http.ResponseWriter, name string, sessionData SessionData) error {
// Create session in database
2024-11-21 00:48:45 -06:00
expiresAt := time.Now().Add(48 * time.Hour) // Set expiration to 48 hours from now
sessionID, err := CreateSession(GetDB(), sessionData.UserID, expiresAt)
2024-06-28 00:35:58 -06:00
if err != nil {
2024-09-16 13:46:25 -06:00
return err
2024-06-28 00:35:58 -06:00
}
2024-09-16 13:46:25 -06:00
sessionData.SessionID = sessionID
2024-06-28 00:35:58 -06:00
// Serialize session data
dataBytes, err := json.Marshal(sessionData)
if err != nil {
2024-09-16 13:46:25 -06:00
return err
2024-06-28 00:35:58 -06:00
}
2024-06-27 21:23:51 -06:00
2024-06-28 00:35:58 -06:00
// Encrypt serialized session data
2024-09-16 13:46:25 -06:00
encryptedData, err := Encrypt(dataBytes)
2024-06-28 00:35:58 -06:00
if err != nil {
2024-09-16 13:46:25 -06:00
return err
2024-06-28 00:35:58 -06:00
}
// Set cookie with encrypted data
2024-06-27 21:23:51 -06:00
http.SetCookie(w, &http.Cookie{
Name: name,
2024-06-28 00:35:58 -06:00
Value: encryptedData,
2024-06-27 21:23:51 -06:00
Path: "/",
HttpOnly: true,
2024-06-28 00:35:58 -06:00
Secure: os.Getenv("DEVMODE") != "true",
2024-06-27 21:23:51 -06:00
SameSite: http.SameSiteStrictMode,
2024-09-16 13:46:25 -06:00
MaxAge: 3600, // 1 hour
2024-06-27 21:23:51 -06:00
})
2024-09-16 13:46:25 -06:00
return nil
2024-06-27 21:23:51 -06:00
}
2024-06-28 00:35:58 -06:00
// Adjusted GetSessionCookie to return SessionData
func GetSessionCookie(r *http.Request, name string) (*SessionData, error) {
2024-06-27 21:23:51 -06:00
cookie, err := r.Cookie(name)
if err != nil {
2024-06-28 00:35:58 -06:00
return nil, err
2024-06-27 21:23:51 -06:00
}
2024-06-28 00:35:58 -06:00
// Decrypt the cookie value
2024-09-16 13:46:25 -06:00
decryptedValue, err := Decrypt(cookie.Value)
2024-06-28 00:35:58 -06:00
if err != nil {
return nil, err
}
// Deserialize session data
var sessionData SessionData
err = json.Unmarshal([]byte(decryptedValue), &sessionData)
if err != nil {
return nil, err
}
return &sessionData, nil
2024-06-27 21:23:51 -06:00
}
2024-06-27 21:51:00 -06:00
// ClearSessionCookie clears the session cookie from the client's browser
func ClearSessionCookie(w http.ResponseWriter, name string) {
http.SetCookie(w, &http.Cookie{
Name: name,
Value: "",
Path: "/",
HttpOnly: true,
Secure: os.Getenv("DEVMODE") != "true",
SameSite: http.SameSiteStrictMode,
2024-06-28 00:35:58 -06:00
MaxAge: -1, // This will delete the cookie
2024-06-27 21:51:00 -06:00
})
}
2024-06-27 21:23:51 -06:00
// Checks if the user is signed in by checking the session cookie
func IsSignedIn(c echo.Context) bool {
2024-06-28 09:11:03 -06:00
sessionData, err := GetSessionCookie(c.Request(), "session")
2024-06-28 00:35:58 -06:00
if err != nil {
log.Printf("Error retrieving session cookie: %v", err)
2024-06-28 09:11:03 -06:00
return false
2024-06-28 00:35:58 -06:00
}
2024-06-28 09:11:03 -06:00
2024-09-16 13:46:25 -06:00
// Validate the session in the database
2024-11-21 00:48:45 -06:00
validSessionData, err := ValidateSession(GetDB(), sessionData.SessionID)
2024-09-16 13:46:25 -06:00
if err != nil || validSessionData == nil {
log.Printf("Invalid session: %v", err)
2024-06-28 09:11:03 -06:00
ClearSessionCookie(c.Response().Writer, "session")
return false
}
return true
2024-06-27 21:23:51 -06:00
}
// GenerateSessionID generates a new session ID.
func GenerateSessionID() (string, error) {
bytes, err := GenerateRandomBytes(16)
if err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
2024-09-16 13:46:25 -06:00
2024-11-21 00:48:45 -06:00
func CreateSession(db *sql.DB, userID string, expiresAt time.Time) (string, error) {
2024-09-16 13:46:25 -06:00
sessionID := GenerateNewID("session")
2024-11-21 00:48:45 -06:00
_, err := db.Exec(
"INSERT INTO sessions (id, user_id, expires_at) VALUES (?, ?, ?)",
2024-09-16 13:46:25 -06:00
sessionID, userID, expiresAt)
if err != nil {
return "", err
}
return sessionID, nil
}
2024-11-21 00:48:45 -06:00
func ValidateSession(db *sql.DB, sessionID string) (*SessionData, error) {
2024-09-16 13:46:25 -06:00
var userID string
var expiresAt time.Time
2024-11-21 00:48:45 -06:00
err := db.QueryRow(
"SELECT user_id, expires_at FROM sessions WHERE id = ? AND expires_at > datetime('now')",
2024-09-16 13:46:25 -06:00
sessionID).Scan(&userID, &expiresAt)
if err != nil {
return nil, err
}
2024-11-21 00:48:45 -06:00
user, err := GetUserByID(db, userID)
2024-09-16 13:46:25 -06:00
if err != nil {
return nil, err
}
return &SessionData{
SessionID: sessionID,
UserID: user.ID,
Name: user.Name,
Email: user.Email,
Roles: []string{"user"},
}, nil
}
2024-11-21 00:48:45 -06:00
func DeleteSession(db *sql.DB, sessionID string) error {
_, err := db.Exec(
"DELETE FROM sessions WHERE id = ?",
sessionID)
2024-09-16 13:46:25 -06:00
return err
}