diff --git a/api/register.go b/api/register.go
index a522680..fe42679 100644
--- a/api/register.go
+++ b/api/register.go
@@ -8,15 +8,25 @@ import (
)
func RegisterUserHandler(c echo.Context) error {
+ name := c.FormValue("name")
email := c.FormValue("email")
password := c.FormValue("password")
+ // Check if the email already exists
+ existingUser, err := lib.GetUserByEmail(lib.GetDBPool(), email)
+ if err == nil && existingUser != nil {
+ // User with the given email already exists
+ errorMessage := `
An account with this email already exists!
`
+ return c.HTML(http.StatusBadRequest, errorMessage)
+ }
+
+ // Proceed with the existing registration logic
hashedPassword, err := lib.HashPassword(password)
if err != nil {
return c.JSON(http.StatusInternalServerError, "Failed to hash password")
}
- user := lib.User{Email: email, Password: hashedPassword}
+ user := lib.User{Name: name, Email: email, Password: hashedPassword}
if err := lib.SaveUser(lib.GetDBPool(), &user); err != nil {
return c.JSON(http.StatusInternalServerError, "Failed to save user")
}
@@ -30,6 +40,7 @@ func RegisterUserHandler(c echo.Context) error {
lib.SetSessionCookie(c.Response().Writer, "session", lib.SessionData{
SessionID: sessionID,
UserID: user.ID,
+ Name: user.Name,
Email: user.Email,
Roles: []string{"user"},
})
diff --git a/lib/auth.go b/lib/auth.go
index 5043251..f8852bf 100644
--- a/lib/auth.go
+++ b/lib/auth.go
@@ -12,6 +12,7 @@ import (
type User struct {
ID string
+ Name string
Email string
Password string
}
@@ -43,6 +44,21 @@ func GetUserByEmail(dbPool *pgxpool.Pool, email string) (*User, error) {
return &user, nil
}
+// GetUserByID fetches a user by ID from the database.
+func GetUserByID(dbPool *pgxpool.Pool, id string) (*User, error) {
+ if dbPool == nil {
+ return nil, errors.New("database connection pool is not initialized")
+ }
+
+ var user User
+ // Ensure the ID is being scanned as a string.
+ err := dbPool.QueryRow(context.Background(), "SELECT id::text, email, password FROM users WHERE id = $1", id).Scan(&user.ID, &user.Email, &user.Password)
+ if err != nil {
+ return nil, err
+ }
+ return &user, nil
+}
+
// SaveUser saves a new user to the database.
func SaveUser(dbPool *pgxpool.Pool, user *User) error {
if dbPool == nil {
diff --git a/lib/db.go b/lib/db.go
index 417639c..cc0956c 100644
--- a/lib/db.go
+++ b/lib/db.go
@@ -26,6 +26,7 @@ func InitializeSchema(dbPool *pgxpool.Pool) error {
const schemaSQL = `
CREATE TABLE IF NOT EXISTS users (
id SERIAL PRIMARY KEY,
+ name VARCHAR(255) NOT NULL,
email VARCHAR(255) UNIQUE NOT NULL,
password VARCHAR(255) NOT NULL
);`
diff --git a/lib/session.go b/lib/session.go
index 59b1d26..11a26ee 100644
--- a/lib/session.go
+++ b/lib/session.go
@@ -21,6 +21,7 @@ import (
type SessionData struct {
SessionID string `json:"sessionID"`
UserID string `json:"userId"`
+ Name string `json:"name"`
Email string `json:"email"`
Roles []string `json:"roles"`
}
@@ -154,12 +155,21 @@ func ClearSessionCookie(w http.ResponseWriter, name string) {
// Checks if the user is signed in by checking the session cookie
func IsSignedIn(c echo.Context) bool {
- _, err := GetSessionCookie(c.Request(), "session")
+ sessionData, err := GetSessionCookie(c.Request(), "session")
if err != nil {
- // Log the error for debugging purposes
log.Printf("Error retrieving session cookie: %v", err)
+ return false
}
- return err == nil
+
+ // Validate the session data by checking if the user exists in the database
+ user, err := GetUserByID(GetDBPool(), sessionData.UserID)
+ if err != nil || user == nil {
+ log.Printf("Session refers to a non-existent user: %v", err)
+ ClearSessionCookie(c.Response().Writer, "session")
+ return false
+ }
+
+ return true
}
// GenerateRandomBytes returns securely generated random bytes.
diff --git a/pages/dashboard.go b/pages/dashboard.go
index 7a5a265..ab57c22 100644
--- a/pages/dashboard.go
+++ b/pages/dashboard.go
@@ -8,7 +8,7 @@ import (
type DashboardProps struct {
IsLoggedIn bool
- Email string
+ Name string
}
func Dashboard(c echo.Context) error {
@@ -19,7 +19,7 @@ func Dashboard(c echo.Context) error {
props := DashboardProps{
IsLoggedIn: lib.IsSignedIn(c),
- Email: currentSession.Email,
+ Name: currentSession.Name,
}
// Specify the partials used by this page
diff --git a/pages/templates/dashboard.html b/pages/templates/dashboard.html
index f73a227..2e5b99f 100644
--- a/pages/templates/dashboard.html
+++ b/pages/templates/dashboard.html
@@ -13,7 +13,7 @@ Pollo // Dashboard
{{define "main"}}
- Hi, {{.Email}}!
+ Hi, {{.Name}}!