Simplified
This commit is contained in:
parent
bc3b61c148
commit
6122cafd44
4 changed files with 45 additions and 63 deletions
|
@ -5,6 +5,11 @@ import (
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
mu sync.Mutex
|
||||||
|
instance *CooldownManager
|
||||||
|
)
|
||||||
|
|
||||||
type CooldownManager struct {
|
type CooldownManager struct {
|
||||||
cooldowns map[string]time.Time
|
cooldowns map[string]time.Time
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
|
@ -16,26 +21,47 @@ func NewCooldownManager() *CooldownManager {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *CooldownManager) StartCooldown(key string, duration time.Duration) {
|
func GetInstance() *CooldownManager {
|
||||||
m.mu.Lock()
|
mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer mu.Unlock()
|
||||||
|
|
||||||
m.cooldowns[key] = time.Now().Add(duration)
|
if instance == nil {
|
||||||
|
instance = &CooldownManager{
|
||||||
|
cooldowns: make(map[string]time.Time),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return instance
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *CooldownManager) IsOnCooldown(key string) bool {
|
func (m *CooldownManager) StartCooldown(userID string, key string, duration time.Duration) {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
cooldownEnd, exists := m.cooldowns[key]
|
m.cooldowns[userID+":"+key] = time.Now().Add(duration)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *CooldownManager) IsOnCooldown(userID string, key string) bool {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
cooldownEnd, exists := m.cooldowns[userID+":"+key]
|
||||||
if !exists {
|
if !exists {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if time.Now().After(cooldownEnd) {
|
if time.Now().After(cooldownEnd) {
|
||||||
delete(m.cooldowns, key)
|
delete(m.cooldowns, userID+":"+key)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CancelCooldown(userID string, key string) {
|
||||||
|
manager := GetInstance()
|
||||||
|
manager.mu.Lock()
|
||||||
|
defer manager.mu.Unlock()
|
||||||
|
|
||||||
|
delete(manager.cooldowns, userID+":"+key)
|
||||||
|
}
|
||||||
|
|
|
@ -78,10 +78,10 @@ func CooldownHandler(event discord.InteractionEvent, key string, duration time.D
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if manager.IsOnCooldown(key) {
|
if manager.IsOnCooldown(user.ID().String(), key) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
manager.StartCooldown(key, duration)
|
manager.StartCooldown(user.ID().String(), key, duration)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,9 +3,9 @@ package lib
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
@ -47,16 +47,15 @@ func OpenAITextGeneration(prompt string) (string, error) {
|
||||||
return resp.Choices[0].Message.Content, nil
|
return resp.Choices[0].Message.Content, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func OpenAIImageGeneration(prompt string) (*bytes.Buffer, error) {
|
func OpenAIImageGeneration(prompt string) (imageFile *bytes.Buffer, err error) {
|
||||||
// Send the generation request to DALL·E 3
|
|
||||||
resp, err := client.CreateImage(context.Background(), openai.ImageRequest{
|
resp, err := client.CreateImage(context.Background(), openai.ImageRequest{
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
Model: "dall-e-3",
|
Model: "dall-e-3",
|
||||||
Size: "1024x1024",
|
Size: "1024x1024",
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Image creation error: %v\n", err)
|
return nil, errors.New("there was an error generating the image based on this prompt... this usually happens when the generated image violates safety requirements")
|
||||||
return nil, fmt.Errorf("failed to generate image")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
imageRes, err := http.Get(resp.Data[0].URL)
|
imageRes, err := http.Get(resp.Data[0].URL)
|
||||||
|
@ -73,6 +72,5 @@ func OpenAIImageGeneration(prompt string) (*bytes.Buffer, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
imageFile := bytes.NewBuffer(imageBytes)
|
return bytes.NewBuffer(imageBytes), nil
|
||||||
return imageFile, nil
|
|
||||||
}
|
}
|
||||||
|
|
52
main.go
52
main.go
|
@ -40,18 +40,7 @@ var commands = []api.CreateCommandData{
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Name: "pic",
|
Name: "pic",
|
||||||
Description: "Generate an image using Stable Diffusion! Cooldown: 1 Minute.",
|
Description: "Generate an image! Cooldown: 1 Minute.",
|
||||||
Options: []discord.CommandOption{
|
|
||||||
&discord.StringOption{
|
|
||||||
OptionName: "prompt",
|
|
||||||
Description: "The prompt for the image generation.",
|
|
||||||
Required: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "hdpic",
|
|
||||||
Description: "Generate an image using DALL·E 3! Cooldown: 10 Minutes.",
|
|
||||||
Options: []discord.CommandOption{
|
Options: []discord.CommandOption{
|
||||||
&discord.StringOption{
|
&discord.StringOption{
|
||||||
OptionName: "prompt",
|
OptionName: "prompt",
|
||||||
|
@ -115,7 +104,6 @@ func newHandler(s *state.State) *handler {
|
||||||
h.AddFunc("ping", h.cmdPing)
|
h.AddFunc("ping", h.cmdPing)
|
||||||
h.AddFunc("ask", h.cmdAsk)
|
h.AddFunc("ask", h.cmdAsk)
|
||||||
h.AddFunc("pic", h.cmdPic)
|
h.AddFunc("pic", h.cmdPic)
|
||||||
h.AddFunc("hdpic", h.cmdHDPic)
|
|
||||||
h.AddFunc("hs", h.cmdHS)
|
h.AddFunc("hs", h.cmdHS)
|
||||||
|
|
||||||
return h
|
return h
|
||||||
|
@ -142,6 +130,7 @@ func (h *handler) cmdAsk(ctx context.Context, data cmdroute.CommandData) *api.In
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := data.Options.Unmarshal(&options); err != nil {
|
if err := data.Options.Unmarshal(&options); err != nil {
|
||||||
|
lib.CancelCooldown(data.Event.User.ID.String(), "ask")
|
||||||
return errorResponse(err)
|
return errorResponse(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -149,6 +138,7 @@ func (h *handler) cmdAsk(ctx context.Context, data cmdroute.CommandData) *api.In
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("ChatCompletion error: %v\n", err)
|
fmt.Printf("ChatCompletion error: %v\n", err)
|
||||||
|
lib.CancelCooldown(data.Event.User.ID.String(), "ask")
|
||||||
return &api.InteractionResponseData{
|
return &api.InteractionResponseData{
|
||||||
Content: option.NewNullableString("ChatCompletion Error!"),
|
Content: option.NewNullableString("ChatCompletion Error!"),
|
||||||
AllowedMentions: &api.AllowedMentions{},
|
AllowedMentions: &api.AllowedMentions{},
|
||||||
|
@ -189,46 +179,14 @@ func (h *handler) cmdPic(ctx context.Context, data cmdroute.CommandData) *api.In
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := data.Options.Unmarshal(&options); err != nil {
|
if err := data.Options.Unmarshal(&options); err != nil {
|
||||||
|
lib.CancelCooldown(data.Event.User.ID.String(), "pic")
|
||||||
return errorResponse(err)
|
return errorResponse(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
imageFile, err := lib.ReplicateImageGeneration(options.Prompt)
|
imageFile, err := lib.ReplicateImageGeneration(options.Prompt)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorResponse(err)
|
lib.CancelCooldown(data.Event.User.ID.String(), "pic")
|
||||||
}
|
|
||||||
|
|
||||||
file := sendpart.File{
|
|
||||||
Name: "himbot_response.png",
|
|
||||||
Reader: imageFile,
|
|
||||||
}
|
|
||||||
|
|
||||||
return &api.InteractionResponseData{
|
|
||||||
Content: option.NewNullableString("Prompt: " + options.Prompt),
|
|
||||||
Files: []sendpart.File{file},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *handler) cmdHDPic(ctx context.Context, data cmdroute.CommandData) *api.InteractionResponseData {
|
|
||||||
// Cooldown Logic
|
|
||||||
allowed := lib.CooldownHandler(*data.Event, "hdPic", time.Minute*10)
|
|
||||||
|
|
||||||
if !allowed {
|
|
||||||
return errorResponse(errors.New("please wait for the cooldown"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Command Logic
|
|
||||||
var options struct {
|
|
||||||
Prompt string `discord:"prompt"`
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := data.Options.Unmarshal(&options); err != nil {
|
|
||||||
return errorResponse(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
imageFile, err := lib.OpenAIImageGeneration(options.Prompt)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return errorResponse(err)
|
return errorResponse(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue