Simplified

This commit is contained in:
Atridad Lahiji 2024-01-10 00:28:31 -07:00
parent bc3b61c148
commit 6122cafd44
No known key found for this signature in database
4 changed files with 45 additions and 63 deletions

View file

@ -5,6 +5,11 @@ import (
"time" "time"
) )
var (
mu sync.Mutex
instance *CooldownManager
)
type CooldownManager struct { type CooldownManager struct {
cooldowns map[string]time.Time cooldowns map[string]time.Time
mu sync.Mutex mu sync.Mutex
@ -16,26 +21,47 @@ func NewCooldownManager() *CooldownManager {
} }
} }
func (m *CooldownManager) StartCooldown(key string, duration time.Duration) { func GetInstance() *CooldownManager {
m.mu.Lock() mu.Lock()
defer m.mu.Unlock() defer mu.Unlock()
m.cooldowns[key] = time.Now().Add(duration) if instance == nil {
instance = &CooldownManager{
cooldowns: make(map[string]time.Time),
}
}
return instance
} }
func (m *CooldownManager) IsOnCooldown(key string) bool { func (m *CooldownManager) StartCooldown(userID string, key string, duration time.Duration) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
cooldownEnd, exists := m.cooldowns[key] m.cooldowns[userID+":"+key] = time.Now().Add(duration)
}
func (m *CooldownManager) IsOnCooldown(userID string, key string) bool {
m.mu.Lock()
defer m.mu.Unlock()
cooldownEnd, exists := m.cooldowns[userID+":"+key]
if !exists { if !exists {
return false return false
} }
if time.Now().After(cooldownEnd) { if time.Now().After(cooldownEnd) {
delete(m.cooldowns, key) delete(m.cooldowns, userID+":"+key)
return false return false
} }
return true return true
} }
func CancelCooldown(userID string, key string) {
manager := GetInstance()
manager.mu.Lock()
defer manager.mu.Unlock()
delete(manager.cooldowns, userID+":"+key)
}

View file

@ -78,10 +78,10 @@ func CooldownHandler(event discord.InteractionEvent, key string, duration time.D
} }
} }
if manager.IsOnCooldown(key) { if manager.IsOnCooldown(user.ID().String(), key) {
return false return false
} }
manager.StartCooldown(key, duration) manager.StartCooldown(user.ID().String(), key, duration)
return true return true
} }

View file

@ -3,9 +3,9 @@ package lib
import ( import (
"bytes" "bytes"
"context" "context"
"errors"
"fmt" "fmt"
"io" "io"
"log"
"net/http" "net/http"
"os" "os"
@ -47,16 +47,15 @@ func OpenAITextGeneration(prompt string) (string, error) {
return resp.Choices[0].Message.Content, nil return resp.Choices[0].Message.Content, nil
} }
func OpenAIImageGeneration(prompt string) (*bytes.Buffer, error) { func OpenAIImageGeneration(prompt string) (imageFile *bytes.Buffer, err error) {
// Send the generation request to DALL·E 3
resp, err := client.CreateImage(context.Background(), openai.ImageRequest{ resp, err := client.CreateImage(context.Background(), openai.ImageRequest{
Prompt: prompt, Prompt: prompt,
Model: "dall-e-3", Model: "dall-e-3",
Size: "1024x1024", Size: "1024x1024",
}) })
if err != nil { if err != nil {
log.Printf("Image creation error: %v\n", err) return nil, errors.New("there was an error generating the image based on this prompt... this usually happens when the generated image violates safety requirements")
return nil, fmt.Errorf("failed to generate image")
} }
imageRes, err := http.Get(resp.Data[0].URL) imageRes, err := http.Get(resp.Data[0].URL)
@ -73,6 +72,5 @@ func OpenAIImageGeneration(prompt string) (*bytes.Buffer, error) {
return nil, err return nil, err
} }
imageFile := bytes.NewBuffer(imageBytes) return bytes.NewBuffer(imageBytes), nil
return imageFile, nil
} }

52
main.go
View file

@ -40,18 +40,7 @@ var commands = []api.CreateCommandData{
}, },
{ {
Name: "pic", Name: "pic",
Description: "Generate an image using Stable Diffusion! Cooldown: 1 Minute.", Description: "Generate an image! Cooldown: 1 Minute.",
Options: []discord.CommandOption{
&discord.StringOption{
OptionName: "prompt",
Description: "The prompt for the image generation.",
Required: true,
},
},
},
{
Name: "hdpic",
Description: "Generate an image using DALL·E 3! Cooldown: 10 Minutes.",
Options: []discord.CommandOption{ Options: []discord.CommandOption{
&discord.StringOption{ &discord.StringOption{
OptionName: "prompt", OptionName: "prompt",
@ -115,7 +104,6 @@ func newHandler(s *state.State) *handler {
h.AddFunc("ping", h.cmdPing) h.AddFunc("ping", h.cmdPing)
h.AddFunc("ask", h.cmdAsk) h.AddFunc("ask", h.cmdAsk)
h.AddFunc("pic", h.cmdPic) h.AddFunc("pic", h.cmdPic)
h.AddFunc("hdpic", h.cmdHDPic)
h.AddFunc("hs", h.cmdHS) h.AddFunc("hs", h.cmdHS)
return h return h
@ -142,6 +130,7 @@ func (h *handler) cmdAsk(ctx context.Context, data cmdroute.CommandData) *api.In
} }
if err := data.Options.Unmarshal(&options); err != nil { if err := data.Options.Unmarshal(&options); err != nil {
lib.CancelCooldown(data.Event.User.ID.String(), "ask")
return errorResponse(err) return errorResponse(err)
} }
@ -149,6 +138,7 @@ func (h *handler) cmdAsk(ctx context.Context, data cmdroute.CommandData) *api.In
if err != nil { if err != nil {
fmt.Printf("ChatCompletion error: %v\n", err) fmt.Printf("ChatCompletion error: %v\n", err)
lib.CancelCooldown(data.Event.User.ID.String(), "ask")
return &api.InteractionResponseData{ return &api.InteractionResponseData{
Content: option.NewNullableString("ChatCompletion Error!"), Content: option.NewNullableString("ChatCompletion Error!"),
AllowedMentions: &api.AllowedMentions{}, AllowedMentions: &api.AllowedMentions{},
@ -189,46 +179,14 @@ func (h *handler) cmdPic(ctx context.Context, data cmdroute.CommandData) *api.In
} }
if err := data.Options.Unmarshal(&options); err != nil { if err := data.Options.Unmarshal(&options); err != nil {
lib.CancelCooldown(data.Event.User.ID.String(), "pic")
return errorResponse(err) return errorResponse(err)
} }
imageFile, err := lib.ReplicateImageGeneration(options.Prompt) imageFile, err := lib.ReplicateImageGeneration(options.Prompt)
if err != nil { if err != nil {
return errorResponse(err) lib.CancelCooldown(data.Event.User.ID.String(), "pic")
}
file := sendpart.File{
Name: "himbot_response.png",
Reader: imageFile,
}
return &api.InteractionResponseData{
Content: option.NewNullableString("Prompt: " + options.Prompt),
Files: []sendpart.File{file},
}
}
func (h *handler) cmdHDPic(ctx context.Context, data cmdroute.CommandData) *api.InteractionResponseData {
// Cooldown Logic
allowed := lib.CooldownHandler(*data.Event, "hdPic", time.Minute*10)
if !allowed {
return errorResponse(errors.New("please wait for the cooldown"))
}
// Command Logic
var options struct {
Prompt string `discord:"prompt"`
}
if err := data.Options.Unmarshal(&options); err != nil {
return errorResponse(err)
}
imageFile, err := lib.OpenAIImageGeneration(options.Prompt)
if err != nil {
return errorResponse(err) return errorResponse(err)
} }