From 6122cafd441402f1376818275a71dda7f4f717a8 Mon Sep 17 00:00:00 2001 From: atridadl Date: Wed, 10 Jan 2024 00:28:31 -0700 Subject: [PATCH] Simplified --- lib/cooldowns.go | 40 ++++++++++++++++++++++++++++++------- lib/helpers.go | 4 ++-- lib/openai.go | 12 +++++------ main.go | 52 +++++------------------------------------------- 4 files changed, 45 insertions(+), 63 deletions(-) diff --git a/lib/cooldowns.go b/lib/cooldowns.go index 0adf1e8..e2cc173 100644 --- a/lib/cooldowns.go +++ b/lib/cooldowns.go @@ -5,6 +5,11 @@ import ( "time" ) +var ( + mu sync.Mutex + instance *CooldownManager +) + type CooldownManager struct { cooldowns map[string]time.Time mu sync.Mutex @@ -16,26 +21,47 @@ func NewCooldownManager() *CooldownManager { } } -func (m *CooldownManager) StartCooldown(key string, duration time.Duration) { - m.mu.Lock() - defer m.mu.Unlock() +func GetInstance() *CooldownManager { + mu.Lock() + defer mu.Unlock() - m.cooldowns[key] = time.Now().Add(duration) + if instance == nil { + instance = &CooldownManager{ + cooldowns: make(map[string]time.Time), + } + } + + return instance } -func (m *CooldownManager) IsOnCooldown(key string) bool { +func (m *CooldownManager) StartCooldown(userID string, key string, duration time.Duration) { m.mu.Lock() defer m.mu.Unlock() - cooldownEnd, exists := m.cooldowns[key] + m.cooldowns[userID+":"+key] = time.Now().Add(duration) +} + +func (m *CooldownManager) IsOnCooldown(userID string, key string) bool { + m.mu.Lock() + defer m.mu.Unlock() + + cooldownEnd, exists := m.cooldowns[userID+":"+key] if !exists { return false } if time.Now().After(cooldownEnd) { - delete(m.cooldowns, key) + delete(m.cooldowns, userID+":"+key) return false } return true } + +func CancelCooldown(userID string, key string) { + manager := GetInstance() + manager.mu.Lock() + defer manager.mu.Unlock() + + delete(manager.cooldowns, userID+":"+key) +} diff --git a/lib/helpers.go b/lib/helpers.go index f7e83a7..74123a5 100644 --- a/lib/helpers.go +++ b/lib/helpers.go @@ -78,10 +78,10 @@ func CooldownHandler(event discord.InteractionEvent, key string, duration time.D } } - if manager.IsOnCooldown(key) { + if manager.IsOnCooldown(user.ID().String(), key) { return false } - manager.StartCooldown(key, duration) + manager.StartCooldown(user.ID().String(), key, duration) return true } diff --git a/lib/openai.go b/lib/openai.go index 0c20d3d..a0f1ae6 100644 --- a/lib/openai.go +++ b/lib/openai.go @@ -3,9 +3,9 @@ package lib import ( "bytes" "context" + "errors" "fmt" "io" - "log" "net/http" "os" @@ -47,16 +47,15 @@ func OpenAITextGeneration(prompt string) (string, error) { return resp.Choices[0].Message.Content, nil } -func OpenAIImageGeneration(prompt string) (*bytes.Buffer, error) { - // Send the generation request to DALL·E 3 +func OpenAIImageGeneration(prompt string) (imageFile *bytes.Buffer, err error) { resp, err := client.CreateImage(context.Background(), openai.ImageRequest{ Prompt: prompt, Model: "dall-e-3", Size: "1024x1024", }) + if err != nil { - log.Printf("Image creation error: %v\n", err) - return nil, fmt.Errorf("failed to generate image") + return nil, errors.New("there was an error generating the image based on this prompt... this usually happens when the generated image violates safety requirements") } imageRes, err := http.Get(resp.Data[0].URL) @@ -73,6 +72,5 @@ func OpenAIImageGeneration(prompt string) (*bytes.Buffer, error) { return nil, err } - imageFile := bytes.NewBuffer(imageBytes) - return imageFile, nil + return bytes.NewBuffer(imageBytes), nil } diff --git a/main.go b/main.go index 0f3ad47..27ddee2 100644 --- a/main.go +++ b/main.go @@ -40,18 +40,7 @@ var commands = []api.CreateCommandData{ }, { Name: "pic", - Description: "Generate an image using Stable Diffusion! Cooldown: 1 Minute.", - Options: []discord.CommandOption{ - &discord.StringOption{ - OptionName: "prompt", - Description: "The prompt for the image generation.", - Required: true, - }, - }, - }, - { - Name: "hdpic", - Description: "Generate an image using DALL·E 3! Cooldown: 10 Minutes.", + Description: "Generate an image! Cooldown: 1 Minute.", Options: []discord.CommandOption{ &discord.StringOption{ OptionName: "prompt", @@ -115,7 +104,6 @@ func newHandler(s *state.State) *handler { h.AddFunc("ping", h.cmdPing) h.AddFunc("ask", h.cmdAsk) h.AddFunc("pic", h.cmdPic) - h.AddFunc("hdpic", h.cmdHDPic) h.AddFunc("hs", h.cmdHS) return h @@ -142,6 +130,7 @@ func (h *handler) cmdAsk(ctx context.Context, data cmdroute.CommandData) *api.In } if err := data.Options.Unmarshal(&options); err != nil { + lib.CancelCooldown(data.Event.User.ID.String(), "ask") return errorResponse(err) } @@ -149,6 +138,7 @@ func (h *handler) cmdAsk(ctx context.Context, data cmdroute.CommandData) *api.In if err != nil { fmt.Printf("ChatCompletion error: %v\n", err) + lib.CancelCooldown(data.Event.User.ID.String(), "ask") return &api.InteractionResponseData{ Content: option.NewNullableString("ChatCompletion Error!"), AllowedMentions: &api.AllowedMentions{}, @@ -189,46 +179,14 @@ func (h *handler) cmdPic(ctx context.Context, data cmdroute.CommandData) *api.In } if err := data.Options.Unmarshal(&options); err != nil { + lib.CancelCooldown(data.Event.User.ID.String(), "pic") return errorResponse(err) } imageFile, err := lib.ReplicateImageGeneration(options.Prompt) if err != nil { - return errorResponse(err) - } - - file := sendpart.File{ - Name: "himbot_response.png", - Reader: imageFile, - } - - return &api.InteractionResponseData{ - Content: option.NewNullableString("Prompt: " + options.Prompt), - Files: []sendpart.File{file}, - } -} - -func (h *handler) cmdHDPic(ctx context.Context, data cmdroute.CommandData) *api.InteractionResponseData { - // Cooldown Logic - allowed := lib.CooldownHandler(*data.Event, "hdPic", time.Minute*10) - - if !allowed { - return errorResponse(errors.New("please wait for the cooldown")) - } - - // Command Logic - var options struct { - Prompt string `discord:"prompt"` - } - - if err := data.Options.Unmarshal(&options); err != nil { - return errorResponse(err) - } - - imageFile, err := lib.OpenAIImageGeneration(options.Prompt) - - if err != nil { + lib.CancelCooldown(data.Event.User.ID.String(), "pic") return errorResponse(err) }