From 7978dd131a893899b331a96712f4b60b9548fc93 Mon Sep 17 00:00:00 2001 From: Atridad Lahiji Date: Fri, 28 Jun 2024 09:11:03 -0600 Subject: [PATCH] Auth fixes --- api/register.go | 13 ++++++++++++- lib/auth.go | 16 ++++++++++++++++ lib/db.go | 1 + lib/session.go | 16 +++++++++++++--- pages/dashboard.go | 4 ++-- pages/templates/dashboard.html | 2 +- pages/templates/register.html | 12 +++++++----- pages/templates/signin.html | 1 - 8 files changed, 52 insertions(+), 13 deletions(-) diff --git a/api/register.go b/api/register.go index a522680..fe42679 100644 --- a/api/register.go +++ b/api/register.go @@ -8,15 +8,25 @@ import ( ) func RegisterUserHandler(c echo.Context) error { + name := c.FormValue("name") email := c.FormValue("email") password := c.FormValue("password") + // Check if the email already exists + existingUser, err := lib.GetUserByEmail(lib.GetDBPool(), email) + if err == nil && existingUser != nil { + // User with the given email already exists + errorMessage := `
An account with this email already exists!
` + return c.HTML(http.StatusBadRequest, errorMessage) + } + + // Proceed with the existing registration logic hashedPassword, err := lib.HashPassword(password) if err != nil { return c.JSON(http.StatusInternalServerError, "Failed to hash password") } - user := lib.User{Email: email, Password: hashedPassword} + user := lib.User{Name: name, Email: email, Password: hashedPassword} if err := lib.SaveUser(lib.GetDBPool(), &user); err != nil { return c.JSON(http.StatusInternalServerError, "Failed to save user") } @@ -30,6 +40,7 @@ func RegisterUserHandler(c echo.Context) error { lib.SetSessionCookie(c.Response().Writer, "session", lib.SessionData{ SessionID: sessionID, UserID: user.ID, + Name: user.Name, Email: user.Email, Roles: []string{"user"}, }) diff --git a/lib/auth.go b/lib/auth.go index 5043251..f8852bf 100644 --- a/lib/auth.go +++ b/lib/auth.go @@ -12,6 +12,7 @@ import ( type User struct { ID string + Name string Email string Password string } @@ -43,6 +44,21 @@ func GetUserByEmail(dbPool *pgxpool.Pool, email string) (*User, error) { return &user, nil } +// GetUserByID fetches a user by ID from the database. +func GetUserByID(dbPool *pgxpool.Pool, id string) (*User, error) { + if dbPool == nil { + return nil, errors.New("database connection pool is not initialized") + } + + var user User + // Ensure the ID is being scanned as a string. + err := dbPool.QueryRow(context.Background(), "SELECT id::text, email, password FROM users WHERE id = $1", id).Scan(&user.ID, &user.Email, &user.Password) + if err != nil { + return nil, err + } + return &user, nil +} + // SaveUser saves a new user to the database. func SaveUser(dbPool *pgxpool.Pool, user *User) error { if dbPool == nil { diff --git a/lib/db.go b/lib/db.go index 417639c..cc0956c 100644 --- a/lib/db.go +++ b/lib/db.go @@ -26,6 +26,7 @@ func InitializeSchema(dbPool *pgxpool.Pool) error { const schemaSQL = ` CREATE TABLE IF NOT EXISTS users ( id SERIAL PRIMARY KEY, + name VARCHAR(255) NOT NULL, email VARCHAR(255) UNIQUE NOT NULL, password VARCHAR(255) NOT NULL );` diff --git a/lib/session.go b/lib/session.go index 59b1d26..11a26ee 100644 --- a/lib/session.go +++ b/lib/session.go @@ -21,6 +21,7 @@ import ( type SessionData struct { SessionID string `json:"sessionID"` UserID string `json:"userId"` + Name string `json:"name"` Email string `json:"email"` Roles []string `json:"roles"` } @@ -154,12 +155,21 @@ func ClearSessionCookie(w http.ResponseWriter, name string) { // Checks if the user is signed in by checking the session cookie func IsSignedIn(c echo.Context) bool { - _, err := GetSessionCookie(c.Request(), "session") + sessionData, err := GetSessionCookie(c.Request(), "session") if err != nil { - // Log the error for debugging purposes log.Printf("Error retrieving session cookie: %v", err) + return false } - return err == nil + + // Validate the session data by checking if the user exists in the database + user, err := GetUserByID(GetDBPool(), sessionData.UserID) + if err != nil || user == nil { + log.Printf("Session refers to a non-existent user: %v", err) + ClearSessionCookie(c.Response().Writer, "session") + return false + } + + return true } // GenerateRandomBytes returns securely generated random bytes. diff --git a/pages/dashboard.go b/pages/dashboard.go index 7a5a265..ab57c22 100644 --- a/pages/dashboard.go +++ b/pages/dashboard.go @@ -8,7 +8,7 @@ import ( type DashboardProps struct { IsLoggedIn bool - Email string + Name string } func Dashboard(c echo.Context) error { @@ -19,7 +19,7 @@ func Dashboard(c echo.Context) error { props := DashboardProps{ IsLoggedIn: lib.IsSignedIn(c), - Email: currentSession.Email, + Name: currentSession.Name, } // Specify the partials used by this page diff --git a/pages/templates/dashboard.html b/pages/templates/dashboard.html index f73a227..2e5b99f 100644 --- a/pages/templates/dashboard.html +++ b/pages/templates/dashboard.html @@ -13,7 +13,7 @@ Pollo // Dashboard {{define "main"}}

- Hi, {{.Email}}! + Hi, {{.Name}}!