Session table work now
This commit is contained in:
parent
32374cfbb4
commit
632568ecf7
5 changed files with 179 additions and 93 deletions
|
@ -25,13 +25,16 @@ func SignInUserHandler(c echo.Context) error {
|
|||
}
|
||||
|
||||
// Set the session cookie with the generated session ID
|
||||
lib.SetSessionCookie(c.Response().Writer, "session", lib.SessionData{
|
||||
err = lib.SetSessionCookie(c.Response().Writer, "session", lib.SessionData{
|
||||
SessionID: sessionID,
|
||||
Name: user.Name,
|
||||
UserID: user.ID,
|
||||
Name: user.Name,
|
||||
Email: user.Email,
|
||||
Roles: []string{"user"},
|
||||
})
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, "Failed to create session")
|
||||
}
|
||||
|
||||
// Proceed with login success logic
|
||||
c.Response().Header().Set("HX-Redirect", "/")
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"pollo/lib"
|
||||
|
||||
|
@ -8,10 +9,19 @@ import (
|
|||
)
|
||||
|
||||
func SignOutUserHandler(c echo.Context) error {
|
||||
sessionData, err := lib.GetSessionCookie(c.Request(), "session")
|
||||
if err == nil && sessionData != nil {
|
||||
// Delete the session from the database
|
||||
err = lib.DeleteSession(lib.GetDBPool(), sessionData.SessionID)
|
||||
if err != nil {
|
||||
log.Printf("Error deleting session: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Clear the session cookie
|
||||
lib.ClearSessionCookie(c.Response().Writer, "session")
|
||||
|
||||
// Proceed with login success logic
|
||||
// Proceed with logout success logic
|
||||
c.Response().Header().Set("HX-Redirect", "/")
|
||||
return c.NoContent(http.StatusOK)
|
||||
}
|
||||
|
|
85
lib/crypto.go
Normal file
85
lib/crypto.go
Normal file
|
@ -0,0 +1,85 @@
|
|||
package lib
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
// Returns the first 32 bytes of the SHA-256 hash of the ENCRYPTION_KEY environment variable
|
||||
func GetEncryptionKey() []byte {
|
||||
key := []byte(os.Getenv("ENCRYPTION_KEY"))
|
||||
hash := sha256.Sum256(key)
|
||||
return hash[:32] // Use the first 32 bytes for AES-256
|
||||
}
|
||||
|
||||
// Encrypt data using AES
|
||||
func Encrypt(data []byte) (string, error) {
|
||||
encryptionKey := GetEncryptionKey()
|
||||
fmt.Printf("Encryption Key Length: %d\n", len(encryptionKey))
|
||||
|
||||
block, err := aes.NewCipher(encryptionKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
_, err = rand.Read(nonce)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
cipherText := gcm.Seal(nonce, nonce, data, nil)
|
||||
return base64.StdEncoding.EncodeToString(cipherText), nil
|
||||
}
|
||||
|
||||
// decrypt decrypts the data using AES-GCM.
|
||||
func Decrypt(encryptedString string) (string, error) {
|
||||
encryptionKey := GetEncryptionKey()
|
||||
|
||||
data, err := base64.StdEncoding.DecodeString(encryptedString)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(encryptionKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
nonceSize := gcm.NonceSize()
|
||||
if len(data) < nonceSize {
|
||||
return "", errors.New("ciphertext too short")
|
||||
}
|
||||
|
||||
nonce, ciphertext := data[:nonceSize], data[nonceSize:]
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
}
|
||||
|
||||
// GenerateRandomBytes returns securely generated random bytes.
|
||||
func GenerateRandomBytes(n int) ([]byte, error) {
|
||||
b := make([]byte, n)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
13
lib/db.go
13
lib/db.go
|
@ -81,6 +81,19 @@ func InitializeSchema(dbPool *pgxpool.Pool) error {
|
|||
CREATE INDEX IF NOT EXISTS idx_rooms_userid ON rooms(userId);
|
||||
`,
|
||||
},
|
||||
{
|
||||
Version: "6_create_sessions_table",
|
||||
SQL: `
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id TEXT NOT NULL,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
expires_at TIMESTAMP NOT NULL,
|
||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions(user_id);
|
||||
`,
|
||||
},
|
||||
}
|
||||
|
||||
// Ensure the migrations table exists
|
||||
|
|
155
lib/session.go
155
lib/session.go
|
@ -1,20 +1,16 @@
|
|||
package lib
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/jackc/pgx/v4/pgxpool"
|
||||
"github.com/labstack/echo-contrib/session"
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
@ -37,80 +33,27 @@ func InitSessionMiddleware() echo.MiddlewareFunc {
|
|||
return session.Middleware(store)
|
||||
}
|
||||
|
||||
// Returns the first 32 bytes of the SHA-256 hash of the ENCRYPTION_KEY environment variable
|
||||
func getEncryptionKey() []byte {
|
||||
key := []byte(os.Getenv("ENCRYPTION_KEY"))
|
||||
hash := sha256.Sum256(key)
|
||||
return hash[:32] // Use the first 32 bytes for AES-256
|
||||
}
|
||||
|
||||
// Encrypt data using AES
|
||||
func encrypt(data []byte) (string, error) {
|
||||
encryptionKey := getEncryptionKey()
|
||||
fmt.Printf("Encryption Key Length: %d\n", len(encryptionKey))
|
||||
|
||||
block, err := aes.NewCipher(encryptionKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
_, err = rand.Read(nonce)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
cipherText := gcm.Seal(nonce, nonce, data, nil)
|
||||
return base64.StdEncoding.EncodeToString(cipherText), nil
|
||||
}
|
||||
|
||||
// decrypt decrypts the data using AES-GCM.
|
||||
func decrypt(encryptedString string) (string, error) {
|
||||
encryptionKey := getEncryptionKey()
|
||||
|
||||
data, err := base64.StdEncoding.DecodeString(encryptedString)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(encryptionKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
nonceSize := gcm.NonceSize()
|
||||
if len(data) < nonceSize {
|
||||
return "", errors.New("ciphertext too short")
|
||||
}
|
||||
|
||||
nonce, ciphertext := data[:nonceSize], data[nonceSize:]
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
}
|
||||
|
||||
// Adjusted SetSessionCookie to include more user info
|
||||
func SetSessionCookie(w http.ResponseWriter, name string, sessionData SessionData) {
|
||||
func SetSessionCookie(w http.ResponseWriter, name string, sessionData SessionData) error {
|
||||
// Create session in database
|
||||
expiresAt := time.Now().Add(48 * time.Hour) // Set expiration to 1 hour from now
|
||||
sessionID, err := CreateSession(GetDBPool(), sessionData.UserID, expiresAt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sessionData.SessionID = sessionID
|
||||
|
||||
// Serialize session data
|
||||
dataBytes, err := json.Marshal(sessionData)
|
||||
if err != nil {
|
||||
log.Fatal("Failed to serialize session data:", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Encrypt serialized session data
|
||||
encryptedData, err := encrypt(dataBytes)
|
||||
encryptedData, err := Encrypt(dataBytes)
|
||||
if err != nil {
|
||||
log.Fatal("Failed to encrypt session data:", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Set cookie with encrypted data
|
||||
|
@ -121,8 +64,10 @@ func SetSessionCookie(w http.ResponseWriter, name string, sessionData SessionDat
|
|||
HttpOnly: true,
|
||||
Secure: os.Getenv("DEVMODE") != "true",
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
MaxAge: 3600,
|
||||
MaxAge: 3600, // 1 hour
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Adjusted GetSessionCookie to return SessionData
|
||||
|
@ -133,7 +78,7 @@ func GetSessionCookie(r *http.Request, name string) (*SessionData, error) {
|
|||
}
|
||||
|
||||
// Decrypt the cookie value
|
||||
decryptedValue, err := decrypt(cookie.Value)
|
||||
decryptedValue, err := Decrypt(cookie.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -169,10 +114,10 @@ func IsSignedIn(c echo.Context) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
// Validate the session data by checking if the user exists in the database
|
||||
user, err := GetUserByID(GetDBPool(), sessionData.UserID)
|
||||
if err != nil || user == nil {
|
||||
log.Printf("Session refers to a non-existent user: %v", err)
|
||||
// Validate the session in the database
|
||||
validSessionData, err := ValidateSession(GetDBPool(), sessionData.SessionID)
|
||||
if err != nil || validSessionData == nil {
|
||||
log.Printf("Invalid session: %v", err)
|
||||
ClearSessionCookie(c.Response().Writer, "session")
|
||||
return false
|
||||
}
|
||||
|
@ -180,17 +125,6 @@ func IsSignedIn(c echo.Context) bool {
|
|||
return true
|
||||
}
|
||||
|
||||
// GenerateRandomBytes returns securely generated random bytes.
|
||||
func GenerateRandomBytes(n int) ([]byte, error) {
|
||||
b := make([]byte, n)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// GenerateSessionID generates a new session ID.
|
||||
func GenerateSessionID() (string, error) {
|
||||
bytes, err := GenerateRandomBytes(16)
|
||||
|
@ -199,3 +133,44 @@ func GenerateSessionID() (string, error) {
|
|||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
func CreateSession(dbPool *pgxpool.Pool, userID string, expiresAt time.Time) (string, error) {
|
||||
sessionID := GenerateNewID("session")
|
||||
_, err := dbPool.Exec(context.Background(),
|
||||
"INSERT INTO sessions (id, user_id, expires_at) VALUES ($1, $2, $3)",
|
||||
sessionID, userID, expiresAt)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return sessionID, nil
|
||||
}
|
||||
|
||||
func ValidateSession(dbPool *pgxpool.Pool, sessionID string) (*SessionData, error) {
|
||||
var userID string
|
||||
var expiresAt time.Time
|
||||
err := dbPool.QueryRow(context.Background(),
|
||||
"SELECT user_id, expires_at FROM sessions WHERE id = $1 AND expires_at > NOW()",
|
||||
sessionID).Scan(&userID, &expiresAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := GetUserByID(dbPool, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &SessionData{
|
||||
SessionID: sessionID,
|
||||
UserID: user.ID,
|
||||
Name: user.Name,
|
||||
Email: user.Email,
|
||||
Roles: []string{"user"},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func DeleteSession(dbPool *pgxpool.Pool, sessionID string) error {
|
||||
_, err := dbPool.Exec(context.Background(),
|
||||
"DELETE FROM sessions WHERE id = $1", sessionID)
|
||||
return err
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue