diff --git a/.air.toml b/.air.toml index fa4b208..48b2dce 100644 --- a/.air.toml +++ b/.air.toml @@ -3,7 +3,7 @@ testdata_dir = "testdata" tmp_dir = "tmp" [build] - args_bin = [] + args_bin = ["-ip", "127.0.0.1", "-port", "3000"] bin = "./tmp/main" pre_cmd = [] cmd = "go build -o ./tmp/main . & cd stylegen && ./gen.sh" diff --git a/api/sse.go b/api/sse.go index b080cba..00eac85 100644 --- a/api/sse.go +++ b/api/sse.go @@ -1,15 +1,19 @@ package api import ( + "errors" "fmt" - "log" "time" "github.com/labstack/echo/v4" "goth.stack/lib" ) -func SSE(c echo.Context) error { +func SSE(c echo.Context, pubSub lib.PubSub) error { + if pubSub == nil { + return errors.New("pubSub is nil") + } + channel := c.QueryParam("channel") if channel == "" { channel = "default" @@ -18,41 +22,29 @@ func SSE(c echo.Context) error { // Use the request context, which is cancelled when the client disconnects ctx := c.Request().Context() - pubsub, _ := lib.Subscribe(lib.RedisClient, channel) + pubsub, err := pubSub.SubscribeToChannel(channel) + if err != nil { + return fmt.Errorf("failed to subscribe to channel: %w", err) + } - c.Response().Header().Set(echo.HeaderContentType, "text/event-stream") - c.Response().Header().Set(echo.HeaderConnection, "keep-alive") - c.Response().Header().Set(echo.HeaderCacheControl, "no-cache") + lib.SetSSEHeaders(c) // Create a ticker that fires every 15 seconds - ticker := time.NewTicker(30 * time.Second) + ticker := lib.CreateTickerAndKeepAlive(c, 30*time.Second) defer ticker.Stop() + // Create a client channel and add it to the SSE server + client := make(chan string) + lib.SSEServer.AddClient(channel, client) + defer lib.SSEServer.RemoveClient(channel, client) + + go lib.HandleIncomingMessages(c, pubsub, client) + for { select { case <-ctx.Done(): // If the client has disconnected, stop the loop return nil - case <-ticker.C: - // Every 30 seconds, send a comment to keep the connection alive - if _, err := c.Response().Write([]byte(": keep-alive\n\n")); err != nil { - return err - } - c.Response().Flush() - default: - // Handle incoming messages as before - msg, err := pubsub.ReceiveMessage(ctx) - if err != nil { - log.Printf("Failed to receive message: %v", err) - continue - } - - data := fmt.Sprintf("data: %s\n\n", msg.Payload) - if _, err := c.Response().Write([]byte(data)); err != nil { - return err - } - - c.Response().Flush() } } } diff --git a/api/ssedemosend.go b/api/ssedemosend.go index 2416b74..db4139f 100644 --- a/api/ssedemosend.go +++ b/api/ssedemosend.go @@ -7,7 +7,7 @@ import ( "goth.stack/lib" ) -func SSEDemoSend(c echo.Context) error { +func SSEDemoSend(c echo.Context, pubSub lib.PubSub) error { channel := c.QueryParam("channel") if channel == "" { channel = "default" @@ -30,8 +30,7 @@ func SSEDemoSend(c echo.Context) error { return c.JSON(http.StatusBadRequest, map[string]string{"error": "message parameter is required"}) } - // Send message - lib.SendSSE("default", message) + lib.SendSSE(c.Request().Context(), pubSub, "default", message) return c.JSON(http.StatusOK, map[string]string{"status": "message sent"}) } diff --git a/lib/types.go b/lib/links.go similarity index 50% rename from lib/types.go rename to lib/links.go index 912f1ea..36a81ea 100644 --- a/lib/types.go +++ b/lib/links.go @@ -19,19 +19,4 @@ type CardLink struct { Internal bool } -type Post struct { - Content template.HTML - Name string - Date string - Tags []string -} -type FrontMatter struct { - Name string - Date string - Tags []string -} -type PubSubMessage struct { - Channel string `json:"channel"` - Data string `json:"data"` -} diff --git a/lib/localpubsub.go b/lib/localpubsub.go new file mode 100644 index 0000000..fb54eaf --- /dev/null +++ b/lib/localpubsub.go @@ -0,0 +1,86 @@ +package lib + +import ( + "context" + "log" + "sync" + "time" +) + +type LocalPubSub struct { + subscribers map[string][]chan Message + lock sync.RWMutex +} + +type LocalPubSubMessage struct { + messages <-chan Message +} + +func (ps *LocalPubSub) SubscribeToChannel(channel string) (PubSubMessage, error) { + ps.lock.Lock() + defer ps.lock.Unlock() + + if ps.subscribers == nil { + ps.subscribers = make(map[string][]chan Message) + } + + ch := make(chan Message, 100) + ps.subscribers[channel] = append(ps.subscribers[channel], ch) + + log.Printf("Subscribed to channel %s", channel) + + return &LocalPubSubMessage{messages: ch}, nil +} + +func (ps *LocalPubSub) PublishToChannel(channel string, message string) error { + ps.lock.RLock() + defer ps.lock.RUnlock() + + if subscribers, ok := ps.subscribers[channel]; ok { + log.Printf("Publishing message to channel %s: %s", channel, message) + for _, ch := range subscribers { + ch <- Message{Payload: message} + } + } else { + log.Printf("No subscribers for channel %s", channel) + } + + return nil +} + +func (m *LocalPubSubMessage) ReceiveMessage(ctx context.Context) (*Message, error) { + for { + select { + case <-ctx.Done(): + // The client has disconnected. Stop trying to send messages. + return nil, ctx.Err() + case msg := <-m.messages: + // A message has been received. Send it to the client. + log.Printf("Received message: %s", msg.Payload) + return &msg, nil + case <-time.After(30 * time.Second): + // No message has been received for 30 seconds. Send a keep-alive message. + return &Message{Payload: "keep-alive"}, nil + } + } +} + +func (ps *LocalPubSub) UnsubscribeFromChannel(channel string, ch <-chan Message) { + ps.lock.Lock() + defer ps.lock.Unlock() + + subscribers := ps.subscribers[channel] + for i, subscriber := range subscribers { + if subscriber == ch { + // Remove the subscriber from the slice + subscribers = append(subscribers[:i], subscribers[i+1:]...) + break + } + } + + if len(subscribers) == 0 { + delete(ps.subscribers, channel) + } else { + ps.subscribers[channel] = subscribers + } +} diff --git a/lib/markdown.go b/lib/markdown.go index a36d7f3..9221752 100644 --- a/lib/markdown.go +++ b/lib/markdown.go @@ -12,6 +12,12 @@ import ( "gopkg.in/yaml.v2" ) +type FrontMatter struct { + Name string + Date string + Tags []string +} + func ExtractFrontMatter(file os.DirEntry, dir string) (CardLink, error) { f, err := os.Open(dir + file.Name()) if err != nil { diff --git a/lib/markdown_test.go b/lib/markdown_test.go deleted file mode 100644 index 190ccf9..0000000 --- a/lib/markdown_test.go +++ /dev/null @@ -1,57 +0,0 @@ -package lib_test - -import ( - "io/fs" - "log" - "os" - "path/filepath" - "testing" - - "github.com/alecthomas/assert/v2" - "goth.stack/lib" -) - -func TestExtractFrontMatter(t *testing.T) { - // Create a temporary file with some front matter - tmpfile, err := os.CreateTemp("../content", "example.*.md") - println(tmpfile.Name()) - if err != nil { - log.Fatal(err) - } - defer os.Remove(tmpfile.Name()) // clean up - - text := `--- -name: "Test Title" -description: "Test Description" ---- - -# Test Content - -` - if _, err := tmpfile.Write([]byte(text)); err != nil { - log.Fatal(err) - } - if err := tmpfile.Close(); err != nil { - log.Fatal(err) - } - - // Get the directory entry for the temporary file - dirEntry, err := os.ReadDir(filepath.Dir(tmpfile.Name())) - if err != nil { - log.Fatal(err) - } - - var tmpFileEntry fs.DirEntry - for _, entry := range dirEntry { - if entry.Name() == filepath.Base(tmpfile.Name()) { - tmpFileEntry = entry - break - } - } - - // Now we can test ExtractFrontMatter - frontMatter, err := lib.ExtractFrontMatter(tmpFileEntry, "../content/") - assert.NoError(t, err) - assert.Equal(t, "Test Title", frontMatter.Name) - assert.Equal(t, "Test Description", frontMatter.Description) -} diff --git a/lib/pubsub.go b/lib/pubsub.go new file mode 100644 index 0000000..d91efde --- /dev/null +++ b/lib/pubsub.go @@ -0,0 +1,16 @@ +package lib + +import "context" + +type Message struct { + Payload string +} + +type PubSubMessage interface { + ReceiveMessage(ctx context.Context) (*Message, error) +} + +type PubSub interface { + SubscribeToChannel(channel string) (PubSubMessage, error) + PublishToChannel(channel string, message string) error +} diff --git a/lib/redis.go b/lib/redis.go index 81c9057..b32bda3 100644 --- a/lib/redis.go +++ b/lib/redis.go @@ -9,11 +9,18 @@ import ( "github.com/redis/go-redis/v9" ) -var ctx = context.Background() - var RedisClient *redis.Client -func NewClient() *redis.Client { +type RedisPubSubMessage struct { + pubsub *redis.PubSub +} + +// RedisPubSub is a Redis implementation of the PubSub interface. +type RedisPubSub struct { + Client *redis.Client +} + +func NewRedisClient() *redis.Client { if RedisClient != nil { return RedisClient } @@ -32,23 +39,29 @@ func NewClient() *redis.Client { return RedisClient } -func Publish(client *redis.Client, channel string, message string) error { - if client == nil { - client = NewClient() - } - - return client.Publish(ctx, channel, message).Err() -} - -func Subscribe(client *redis.Client, channel string) (*redis.PubSub, string) { - if client == nil { - client = NewClient() - } - - pubsub := client.Subscribe(ctx, channel) - _, err := pubsub.Receive(ctx) +func (m *RedisPubSubMessage) ReceiveMessage(ctx context.Context) (*Message, error) { + msg, err := m.pubsub.ReceiveMessage(ctx) if err != nil { - log.Fatalf("Error receiving subscription: %v", err) + return nil, err } - return pubsub, channel + + return &Message{Payload: msg.Payload}, nil +} + +func (ps *RedisPubSub) SubscribeToChannel(channel string) (PubSubMessage, error) { + pubsub := ps.Client.Subscribe(context.Background(), channel) + _, err := pubsub.Receive(context.Background()) + if err != nil { + return nil, err + } + + return &RedisPubSubMessage{pubsub: pubsub}, nil +} + +func (r *RedisPubSub) PublishToChannel(channel string, message string) error { + err := r.Client.Publish(context.Background(), channel, message).Err() + if err != nil { + return err + } + return nil } diff --git a/lib/redis_test.go b/lib/redis_test.go deleted file mode 100644 index 7643613..0000000 --- a/lib/redis_test.go +++ /dev/null @@ -1,27 +0,0 @@ -package lib_test - -import ( - "testing" - - "github.com/go-redis/redismock/v9" - "github.com/stretchr/testify/assert" - "goth.stack/lib" -) - -func TestPublish(t *testing.T) { - db, mock := redismock.NewClientMock() - mock.ExpectPublish("mychannel", "mymessage").SetVal(1) - - err := lib.Publish(db, "mychannel", "mymessage") - assert.NoError(t, err) - assert.NoError(t, mock.ExpectationsWereMet()) -} - -// Then you can check the channel name in your test -func TestSubscribe(t *testing.T) { - db, _ := redismock.NewClientMock() - - pubsub, channel := lib.Subscribe(db, "mychannel") - assert.NotNil(t, pubsub) - assert.Equal(t, "mychannel", channel) -} diff --git a/lib/sse.go b/lib/sse.go index 585c9f8..544ba66 100644 --- a/lib/sse.go +++ b/lib/sse.go @@ -1,6 +1,15 @@ package lib -import "sync" +import ( + "context" + "fmt" + "log" + "net/http" + "sync" + "time" + + "github.com/labstack/echo/v4" +) type SSEServerType struct { clients map[string]map[chan string]bool @@ -8,6 +17,7 @@ type SSEServerType struct { } var SSEServer *SSEServerType +var mutex = &sync.Mutex{} func init() { SSEServer = &SSEServerType{ @@ -48,14 +58,20 @@ func (s *SSEServerType) ClientCount(channel string) int { return len(s.clients[channel]) } -func SendSSE(channel string, message string) error { +func SendSSE(ctx context.Context, messageBroker PubSub, channel string, message string) error { // Create a channel to receive an error from the goroutine errCh := make(chan error, 1) // Use a goroutine to send the message asynchronously go func() { - err := Publish(RedisClient, channel, message) - errCh <- err // Send the error to the channel + select { + case <-ctx.Done(): + // The client has disconnected, so return an error + errCh <- ctx.Err() + default: + err := messageBroker.PublishToChannel(channel, message) + errCh <- err // Send the error to the channel + } }() // Wait for the goroutine to finish and check for errors @@ -66,3 +82,63 @@ func SendSSE(channel string, message string) error { return nil } + +func SetSSEHeaders(c echo.Context) { + c.Response().Header().Set(echo.HeaderContentType, "text/event-stream") + c.Response().Header().Set(echo.HeaderConnection, "keep-alive") + c.Response().Header().Set(echo.HeaderCacheControl, "no-cache") +} + +func CreateTickerAndKeepAlive(c echo.Context, duration time.Duration) *time.Ticker { + ticker := time.NewTicker(duration) + go func() { + for range ticker.C { + if _, err := c.Response().Write([]byte(": keep-alive\n\n")); err != nil { + log.Printf("Failed to write keep-alive: %v", err) + } + c.Response().Flush() + } + }() + return ticker +} + +func HandleIncomingMessages(c echo.Context, pubsub PubSubMessage, client chan string) { + for { + select { + case <-c.Request().Context().Done(): + // The client has disconnected. Stop trying to send messages. + return + default: + // The client is still connected. Continue processing messages. + msg, err := pubsub.ReceiveMessage(c.Request().Context()) + if err != nil { + log.Printf("Failed to receive message: %v", err) + continue + } + + data := fmt.Sprintf("data: %s\n\n", msg.Payload) + + mutex.Lock() + _, err = c.Response().Write([]byte(data)) + mutex.Unlock() + + if err != nil { + log.Printf("Failed to write message: %v", err) + return // Stop processing if an error occurs + } + + // Check if the ResponseWriter is nil before trying to flush it + if c.Response().Writer != nil { + // Check if the ResponseWriter implements http.Flusher before calling Flush + flusher, ok := c.Response().Writer.(http.Flusher) + if ok { + flusher.Flush() + } else { + log.Println("Failed to flush: ResponseWriter does not implement http.Flusher") + } + } else { + log.Println("Failed to flush: ResponseWriter is nil") + } + } + } +} diff --git a/main.go b/main.go index f0533d5..69fdeeb 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,9 @@ package main import ( + "context" + "flag" + "fmt" "log" "net/http" @@ -9,6 +12,7 @@ import ( "github.com/labstack/echo/v4/middleware" "goth.stack/api" + "goth.stack/lib" "goth.stack/pages" ) @@ -16,6 +20,24 @@ func main() { // Load environment variables godotenv.Load(".env") + // Initialize Redis client + lib.RedisClient = lib.NewRedisClient() + + // Test Redis connection + _, err := lib.RedisClient.Ping(context.Background()).Result() + + // Initialize pubsub + var pubSub lib.PubSub + if err != nil { + log.Printf("Failed to connect to Redis: %v", err) + log.Println("Falling back to LocalPubSub") + pubSub = &lib.LocalPubSub{} + } else { + pubSub = &lib.RedisPubSub{ + Client: lib.RedisClient, + } + } + // Initialize Echo router e := echo.New() @@ -23,7 +45,6 @@ func main() { e.Use(middleware.Logger()) e.Use(middleware.Recover()) e.Pre(middleware.RemoveTrailingSlash()) - e.Use(middleware.Logger()) e.Use(middleware.RequestID()) e.Use(middleware.Secure()) e.Use(middleware.GzipWithConfig(middleware.GzipConfig{ @@ -43,14 +64,25 @@ func main() { // API Routes: apiGroup := e.Group("/api") apiGroup.GET("/ping", api.Ping) - apiGroup.GET("/sse", api.SSE) - apiGroup.POST("/sendsse", api.SSEDemoSend) + + apiGroup.GET("/sse", func(c echo.Context) error { + return api.SSE(c, pubSub) + }) + + apiGroup.POST("/sendsse", func(c echo.Context) error { + return api.SSEDemoSend(c, pubSub) + }) + + // Parse command-line arguments for IP and port + ip := flag.String("ip", "", "IP address to bind the server to") + port := flag.String("port", "3000", "Port to bind the server to") + flag.Parse() // Start server with HTTP/2 support s := &http.Server{ - Addr: ":3000", + Addr: fmt.Sprintf("%s:%s", *ip, *port), Handler: e, } e.Logger.Fatal(e.StartServer(s)) - log.Println("Server started on port 3000") + log.Println("Server started on port", *port) }