Session table work now
This commit is contained in:
parent
32374cfbb4
commit
632568ecf7
5 changed files with 179 additions and 93 deletions
|
@ -25,13 +25,16 @@ func SignInUserHandler(c echo.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the session cookie with the generated session ID
|
// Set the session cookie with the generated session ID
|
||||||
lib.SetSessionCookie(c.Response().Writer, "session", lib.SessionData{
|
err = lib.SetSessionCookie(c.Response().Writer, "session", lib.SessionData{
|
||||||
SessionID: sessionID,
|
SessionID: sessionID,
|
||||||
Name: user.Name,
|
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
|
Name: user.Name,
|
||||||
Email: user.Email,
|
Email: user.Email,
|
||||||
Roles: []string{"user"},
|
Roles: []string{"user"},
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
return c.JSON(http.StatusInternalServerError, "Failed to create session")
|
||||||
|
}
|
||||||
|
|
||||||
// Proceed with login success logic
|
// Proceed with login success logic
|
||||||
c.Response().Header().Set("HX-Redirect", "/")
|
c.Response().Header().Set("HX-Redirect", "/")
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"pollo/lib"
|
"pollo/lib"
|
||||||
|
|
||||||
|
@ -8,10 +9,19 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func SignOutUserHandler(c echo.Context) error {
|
func SignOutUserHandler(c echo.Context) error {
|
||||||
|
sessionData, err := lib.GetSessionCookie(c.Request(), "session")
|
||||||
|
if err == nil && sessionData != nil {
|
||||||
|
// Delete the session from the database
|
||||||
|
err = lib.DeleteSession(lib.GetDBPool(), sessionData.SessionID)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error deleting session: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Clear the session cookie
|
// Clear the session cookie
|
||||||
lib.ClearSessionCookie(c.Response().Writer, "session")
|
lib.ClearSessionCookie(c.Response().Writer, "session")
|
||||||
|
|
||||||
// Proceed with login success logic
|
// Proceed with logout success logic
|
||||||
c.Response().Header().Set("HX-Redirect", "/")
|
c.Response().Header().Set("HX-Redirect", "/")
|
||||||
return c.NoContent(http.StatusOK)
|
return c.NoContent(http.StatusOK)
|
||||||
}
|
}
|
||||||
|
|
85
lib/crypto.go
Normal file
85
lib/crypto.go
Normal file
|
@ -0,0 +1,85 @@
|
||||||
|
package lib
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/aes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Returns the first 32 bytes of the SHA-256 hash of the ENCRYPTION_KEY environment variable
|
||||||
|
func GetEncryptionKey() []byte {
|
||||||
|
key := []byte(os.Getenv("ENCRYPTION_KEY"))
|
||||||
|
hash := sha256.Sum256(key)
|
||||||
|
return hash[:32] // Use the first 32 bytes for AES-256
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt data using AES
|
||||||
|
func Encrypt(data []byte) (string, error) {
|
||||||
|
encryptionKey := GetEncryptionKey()
|
||||||
|
fmt.Printf("Encryption Key Length: %d\n", len(encryptionKey))
|
||||||
|
|
||||||
|
block, err := aes.NewCipher(encryptionKey)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
gcm, err := cipher.NewGCM(block)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
nonce := make([]byte, gcm.NonceSize())
|
||||||
|
_, err = rand.Read(nonce)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
cipherText := gcm.Seal(nonce, nonce, data, nil)
|
||||||
|
return base64.StdEncoding.EncodeToString(cipherText), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// decrypt decrypts the data using AES-GCM.
|
||||||
|
func Decrypt(encryptedString string) (string, error) {
|
||||||
|
encryptionKey := GetEncryptionKey()
|
||||||
|
|
||||||
|
data, err := base64.StdEncoding.DecodeString(encryptedString)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
block, err := aes.NewCipher(encryptionKey)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
gcm, err := cipher.NewGCM(block)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
nonceSize := gcm.NonceSize()
|
||||||
|
if len(data) < nonceSize {
|
||||||
|
return "", errors.New("ciphertext too short")
|
||||||
|
}
|
||||||
|
|
||||||
|
nonce, ciphertext := data[:nonceSize], data[nonceSize:]
|
||||||
|
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(plaintext), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateRandomBytes returns securely generated random bytes.
|
||||||
|
func GenerateRandomBytes(n int) ([]byte, error) {
|
||||||
|
b := make([]byte, n)
|
||||||
|
_, err := rand.Read(b)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return b, nil
|
||||||
|
}
|
13
lib/db.go
13
lib/db.go
|
@ -81,6 +81,19 @@ func InitializeSchema(dbPool *pgxpool.Pool) error {
|
||||||
CREATE INDEX IF NOT EXISTS idx_rooms_userid ON rooms(userId);
|
CREATE INDEX IF NOT EXISTS idx_rooms_userid ON rooms(userId);
|
||||||
`,
|
`,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Version: "6_create_sessions_table",
|
||||||
|
SQL: `
|
||||||
|
CREATE TABLE IF NOT EXISTS sessions (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
expires_at TIMESTAMP NOT NULL,
|
||||||
|
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions(user_id);
|
||||||
|
`,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure the migrations table exists
|
// Ensure the migrations table exists
|
||||||
|
|
155
lib/session.go
155
lib/session.go
|
@ -1,20 +1,16 @@
|
||||||
package lib
|
package lib
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/aes"
|
"context"
|
||||||
"crypto/cipher"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/sessions"
|
"github.com/gorilla/sessions"
|
||||||
|
"github.com/jackc/pgx/v4/pgxpool"
|
||||||
"github.com/labstack/echo-contrib/session"
|
"github.com/labstack/echo-contrib/session"
|
||||||
"github.com/labstack/echo/v4"
|
"github.com/labstack/echo/v4"
|
||||||
)
|
)
|
||||||
|
@ -37,80 +33,27 @@ func InitSessionMiddleware() echo.MiddlewareFunc {
|
||||||
return session.Middleware(store)
|
return session.Middleware(store)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns the first 32 bytes of the SHA-256 hash of the ENCRYPTION_KEY environment variable
|
|
||||||
func getEncryptionKey() []byte {
|
|
||||||
key := []byte(os.Getenv("ENCRYPTION_KEY"))
|
|
||||||
hash := sha256.Sum256(key)
|
|
||||||
return hash[:32] // Use the first 32 bytes for AES-256
|
|
||||||
}
|
|
||||||
|
|
||||||
// Encrypt data using AES
|
|
||||||
func encrypt(data []byte) (string, error) {
|
|
||||||
encryptionKey := getEncryptionKey()
|
|
||||||
fmt.Printf("Encryption Key Length: %d\n", len(encryptionKey))
|
|
||||||
|
|
||||||
block, err := aes.NewCipher(encryptionKey)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
gcm, err := cipher.NewGCM(block)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
nonce := make([]byte, gcm.NonceSize())
|
|
||||||
_, err = rand.Read(nonce)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
cipherText := gcm.Seal(nonce, nonce, data, nil)
|
|
||||||
return base64.StdEncoding.EncodeToString(cipherText), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// decrypt decrypts the data using AES-GCM.
|
|
||||||
func decrypt(encryptedString string) (string, error) {
|
|
||||||
encryptionKey := getEncryptionKey()
|
|
||||||
|
|
||||||
data, err := base64.StdEncoding.DecodeString(encryptedString)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
block, err := aes.NewCipher(encryptionKey)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
gcm, err := cipher.NewGCM(block)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
nonceSize := gcm.NonceSize()
|
|
||||||
if len(data) < nonceSize {
|
|
||||||
return "", errors.New("ciphertext too short")
|
|
||||||
}
|
|
||||||
|
|
||||||
nonce, ciphertext := data[:nonceSize], data[nonceSize:]
|
|
||||||
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return string(plaintext), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Adjusted SetSessionCookie to include more user info
|
// Adjusted SetSessionCookie to include more user info
|
||||||
func SetSessionCookie(w http.ResponseWriter, name string, sessionData SessionData) {
|
func SetSessionCookie(w http.ResponseWriter, name string, sessionData SessionData) error {
|
||||||
|
// Create session in database
|
||||||
|
expiresAt := time.Now().Add(48 * time.Hour) // Set expiration to 1 hour from now
|
||||||
|
sessionID, err := CreateSession(GetDBPool(), sessionData.UserID, expiresAt)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionData.SessionID = sessionID
|
||||||
|
|
||||||
// Serialize session data
|
// Serialize session data
|
||||||
dataBytes, err := json.Marshal(sessionData)
|
dataBytes, err := json.Marshal(sessionData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("Failed to serialize session data:", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Encrypt serialized session data
|
// Encrypt serialized session data
|
||||||
encryptedData, err := encrypt(dataBytes)
|
encryptedData, err := Encrypt(dataBytes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("Failed to encrypt session data:", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set cookie with encrypted data
|
// Set cookie with encrypted data
|
||||||
|
@ -121,8 +64,10 @@ func SetSessionCookie(w http.ResponseWriter, name string, sessionData SessionDat
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
Secure: os.Getenv("DEVMODE") != "true",
|
Secure: os.Getenv("DEVMODE") != "true",
|
||||||
SameSite: http.SameSiteStrictMode,
|
SameSite: http.SameSiteStrictMode,
|
||||||
MaxAge: 3600,
|
MaxAge: 3600, // 1 hour
|
||||||
})
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Adjusted GetSessionCookie to return SessionData
|
// Adjusted GetSessionCookie to return SessionData
|
||||||
|
@ -133,7 +78,7 @@ func GetSessionCookie(r *http.Request, name string) (*SessionData, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decrypt the cookie value
|
// Decrypt the cookie value
|
||||||
decryptedValue, err := decrypt(cookie.Value)
|
decryptedValue, err := Decrypt(cookie.Value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -169,10 +114,10 @@ func IsSignedIn(c echo.Context) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate the session data by checking if the user exists in the database
|
// Validate the session in the database
|
||||||
user, err := GetUserByID(GetDBPool(), sessionData.UserID)
|
validSessionData, err := ValidateSession(GetDBPool(), sessionData.SessionID)
|
||||||
if err != nil || user == nil {
|
if err != nil || validSessionData == nil {
|
||||||
log.Printf("Session refers to a non-existent user: %v", err)
|
log.Printf("Invalid session: %v", err)
|
||||||
ClearSessionCookie(c.Response().Writer, "session")
|
ClearSessionCookie(c.Response().Writer, "session")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -180,17 +125,6 @@ func IsSignedIn(c echo.Context) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateRandomBytes returns securely generated random bytes.
|
|
||||||
func GenerateRandomBytes(n int) ([]byte, error) {
|
|
||||||
b := make([]byte, n)
|
|
||||||
_, err := rand.Read(b)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return b, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GenerateSessionID generates a new session ID.
|
// GenerateSessionID generates a new session ID.
|
||||||
func GenerateSessionID() (string, error) {
|
func GenerateSessionID() (string, error) {
|
||||||
bytes, err := GenerateRandomBytes(16)
|
bytes, err := GenerateRandomBytes(16)
|
||||||
|
@ -199,3 +133,44 @@ func GenerateSessionID() (string, error) {
|
||||||
}
|
}
|
||||||
return hex.EncodeToString(bytes), nil
|
return hex.EncodeToString(bytes), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CreateSession(dbPool *pgxpool.Pool, userID string, expiresAt time.Time) (string, error) {
|
||||||
|
sessionID := GenerateNewID("session")
|
||||||
|
_, err := dbPool.Exec(context.Background(),
|
||||||
|
"INSERT INTO sessions (id, user_id, expires_at) VALUES ($1, $2, $3)",
|
||||||
|
sessionID, userID, expiresAt)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return sessionID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ValidateSession(dbPool *pgxpool.Pool, sessionID string) (*SessionData, error) {
|
||||||
|
var userID string
|
||||||
|
var expiresAt time.Time
|
||||||
|
err := dbPool.QueryRow(context.Background(),
|
||||||
|
"SELECT user_id, expires_at FROM sessions WHERE id = $1 AND expires_at > NOW()",
|
||||||
|
sessionID).Scan(&userID, &expiresAt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := GetUserByID(dbPool, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &SessionData{
|
||||||
|
SessionID: sessionID,
|
||||||
|
UserID: user.ID,
|
||||||
|
Name: user.Name,
|
||||||
|
Email: user.Email,
|
||||||
|
Roles: []string{"user"},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeleteSession(dbPool *pgxpool.Pool, sessionID string) error {
|
||||||
|
_, err := dbPool.Exec(context.Background(),
|
||||||
|
"DELETE FROM sessions WHERE id = $1", sessionID)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue