diff --git a/command/ask.go b/command/ask.go index dd9936b..b791a38 100644 --- a/command/ask.go +++ b/command/ask.go @@ -32,7 +32,7 @@ func Ask(ctx context.Context, data cmdroute.CommandData) *api.InteractionRespons return lib.ErrorResponse(err) } - respString, err := lib.ReplicateTextGeneration(options.Prompt) + respString, err := lib.OpenAITextGeneration(options.Prompt) if err != nil { fmt.Printf("ChatCompletion error: %v\n", err) diff --git a/command/code.go b/command/code.go deleted file mode 100644 index f0d5d78..0000000 --- a/command/code.go +++ /dev/null @@ -1,64 +0,0 @@ -package command - -import ( - "bytes" - "context" - "errors" - "fmt" - "himbot/lib" - "time" - - "github.com/diamondburned/arikawa/v3/api" - "github.com/diamondburned/arikawa/v3/api/cmdroute" - "github.com/diamondburned/arikawa/v3/utils/json/option" - "github.com/diamondburned/arikawa/v3/utils/sendpart" -) - -func Code(ctx context.Context, data cmdroute.CommandData) *api.InteractionResponseData { - // Cooldown Logic - allowed, cooldownString := lib.CooldownHandler(*data.Event, "code", time.Minute) - - if !allowed { - return lib.ErrorResponse(errors.New(cooldownString)) - } - - // Command Logic - var options struct { - Prompt string `discord:"prompt"` - } - - if err := data.Options.Unmarshal(&options); err != nil { - lib.CancelCooldown(data.Event.User.ID.String(), "ask") - return lib.ErrorResponse(err) - } - - respString, err := lib.ReplicateCodeGeneration(options.Prompt) - - if err != nil { - fmt.Printf("ChatCompletion error: %v\n", err) - lib.CancelCooldown(data.Event.User.ID.String(), "ask") - return &api.InteractionResponseData{ - Content: option.NewNullableString("ChatCompletion Error!"), - AllowedMentions: &api.AllowedMentions{}, - } - } - - if len(respString) > 1800 { - textFile := bytes.NewBuffer([]byte(respString)) - - file := sendpart.File{ - Name: "himbot_response.md", - Reader: textFile, - } - - return &api.InteractionResponseData{ - Content: option.NewNullableString("Prompt: " + options.Prompt + "\n"), - AllowedMentions: &api.AllowedMentions{}, - Files: []sendpart.File{file}, - } - } - return &api.InteractionResponseData{ - Content: option.NewNullableString("Prompt: " + options.Prompt + "\n--------------------\n" + respString), - AllowedMentions: &api.AllowedMentions{}, - } -} diff --git a/command/pic.go b/command/pic.go index ee1d58f..16179cb 100644 --- a/command/pic.go +++ b/command/pic.go @@ -37,7 +37,7 @@ func Pic(ctx context.Context, data cmdroute.CommandData) *api.InteractionRespons // Concatenate clean username and timestamp to form filename filename := data.Event.Sender().Username + "_" + timestamp + ".jpg" - imageFile, err := lib.ReplicateImageGeneration(options.Prompt, filename) + imageFile, err := lib.OpenAIImageGeneration(options.Prompt, filename) if err != nil { lib.CancelCooldown(data.Event.User.ID.String(), "pic") diff --git a/go.mod b/go.mod index 80391e5..b8f6152 100644 --- a/go.mod +++ b/go.mod @@ -5,16 +5,17 @@ go 1.22.0 require github.com/diamondburned/arikawa/v3 v3.3.5 require ( + github.com/davecgh/go-spew v1.1.1 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect golang.org/x/net v0.21.0 // indirect - golang.org/x/sync v0.6.0 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect ) require ( - github.com/aws/aws-sdk-go v1.50.24 + github.com/aws/aws-sdk-go v1.50.25 github.com/gorilla/schema v1.2.1 // indirect github.com/gorilla/websocket v1.5.1 // indirect github.com/joho/godotenv v1.5.1 - github.com/replicate/replicate-go v0.16.1 + github.com/sashabaranov/go-openai v1.20.0 golang.org/x/time v0.5.0 // indirect ) diff --git a/go.sum b/go.sum index e6f62a6..8f71464 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/aws/aws-sdk-go v1.50.24 h1:3o2Pg7mOoVL0jv54vWtuafoZqAeEXLhm1tltWA2GcEw= -github.com/aws/aws-sdk-go v1.50.24/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk= +github.com/aws/aws-sdk-go v1.50.25 h1:vhiHtLYybv1Nhx3Kv18BBC6L0aPJHaG9aeEsr92W99c= +github.com/aws/aws-sdk-go v1.50.25/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -19,11 +19,9 @@ github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/replicate/replicate-go v0.16.1 h1:LbImDfB6ef0yEfWbBNQdnC5CoKmHxonoa/UUJ6YrFC8= -github.com/replicate/replicate-go v0.16.1/go.mod h1:otIrl1vDmyjNhTzmVmp/mQU3Wt1+3387gFNEsAZq0ig= +github.com/sashabaranov/go-openai v1.20.0 h1:r9WiwJY6Q2aPDhVyfOSKm83Gs04ogN1yaaBoQOnusS4= +github.com/sashabaranov/go-openai v1.20.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= @@ -40,8 +38,6 @@ golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= -golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -74,5 +70,3 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/lib/openai.go b/lib/openai.go new file mode 100644 index 0000000..864333b --- /dev/null +++ b/lib/openai.go @@ -0,0 +1,108 @@ +package lib + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "log" + "net/http" + "os" + + "github.com/joho/godotenv" + "github.com/sashabaranov/go-openai" +) + +var PromptPrefix = "Your name is Himbot. You are a helpful but sarcastic and witty discord bot. Please respond with a natural response to the following prompt with that personality in mind:" + +func OpenAITextGeneration(prompt string) (string, error) { + godotenv.Load(".env") + apiKey := os.Getenv("OPENAI_API_KEY") + + client := openai.NewClient(apiKey) + + resp, err := client.CreateChatCompletion( + context.Background(), + openai.ChatCompletionRequest{ + Model: openai.GPT4Turbo1106, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: PromptPrefix + prompt, + }, + }, + }, + ) + + if err != nil { + fmt.Printf("Ask command error: %v\n", err) + return "", errors.New("there was an error generating the response based on this prompt... please reach out to @himbothyswaggins to fix this issue") + + } + + return resp.Choices[0].Message.Content, nil +} + +func OpenAIImageGeneration(prompt string, filename string) (*bytes.Buffer, error) { + godotenv.Load(".env") + apiKey := os.Getenv("OPENAI_API_KEY") + + client := openai.NewClient(apiKey) + + imageResponse, err := client.CreateImage( + context.Background(), + openai.ImageRequest{ + Model: openai.CreateImageModelDallE3, + Prompt: prompt, + Size: openai.CreateImageSize1024x1024, + Quality: openai.CreateImageQualityStandard, + ResponseFormat: openai.CreateImageResponseFormatURL, + N: 1, + }, + ) + + if err != nil { + fmt.Printf("Pic command error: %v\n", err) + return nil, errors.New("there was an error generating the image based on this prompt... please reach out to @himbothyswaggins to fix this issue") + } + + imgUrl := imageResponse.Data[0].URL + + imageRes, imageGetErr := http.Get(imgUrl) + if imageGetErr != nil { + return nil, imageGetErr + } + + defer imageRes.Body.Close() + + imageBytes, imgReadErr := io.ReadAll(imageRes.Body) + if imgReadErr != nil { + return nil, imgReadErr + } + + // Save image to a temporary file + tmpfile, err := os.Create(filename) + + if err != nil { + log.Fatal(err) + } + + defer os.Remove(tmpfile.Name()) + + if _, err := tmpfile.Write(imageBytes); err != nil { + log.Fatal(err) + } + if err := tmpfile.Close(); err != nil { + log.Fatal(err) + } + + // Upload the image to S3 + _, uploadErr := UploadToS3(tmpfile.Name()) + if uploadErr != nil { + log.Printf("Failed to upload image to S3: %v", uploadErr) + } + + imageFile := bytes.NewBuffer(imageBytes) + return imageFile, nil +} diff --git a/lib/replicate.go b/lib/replicate.go deleted file mode 100644 index 37e2af8..0000000 --- a/lib/replicate.go +++ /dev/null @@ -1,189 +0,0 @@ -package lib - -import ( - "bytes" - "context" - "errors" - "io" - "log" - "net/http" - "os" - "strings" - - "github.com/replicate/replicate-go" -) - -var PromptPrefix = "Ready for a dose of sarcasm and wit? Himbot, your Discord assistant, is up for the challenge. Hit it with the prompt:" - -func ReplicateTextGeneration(prompt string) (string, error) { - client, clientError := replicate.NewClient(replicate.WithTokenFromEnv()) - if clientError != nil { - return "", clientError - } - - input := replicate.PredictionInput{ - "prompt": PromptPrefix + prompt, - "max_new_tokens": 4096, - "prompt_template": "[INST] {prompt} [/INST]", - } - webhook := replicate.Webhook{ - URL: "https://example.com/webhook", - Events: []replicate.WebhookEventType{"start", "completed"}, - } - - prediction, predictionError := client.Run(context.Background(), "mistralai/mixtral-8x7b-instruct-v0.1:cf18decbf51c27fed6bbdc3492312c1c903222a56e3fe9ca02d6cbe5198afc10", input, &webhook) - - if predictionError != nil { - return "", predictionError - } - - if prediction == nil { - return "", errors.New("there was an error generating a response based on this prompt... please reach out to @himbothyswaggins to fix this issue") - } - - test, ok := prediction.([]interface{}) - - if !ok { - return "", errors.New("there was an error generating a response based on this prompt... please reach out to @himbothyswaggins to fix this issue") - } - - strs := make([]string, len(test)) - for i, v := range test { - str, ok := v.(string) - if !ok { - return "", errors.New("element is not a string") - } - strs[i] = str - } - - result := strings.Join(strs, "") - - return result, nil -} - -func ReplicateCodeGeneration(prompt string) (string, error) { - client, clientError := replicate.NewClient(replicate.WithTokenFromEnv()) - if clientError != nil { - return "", clientError - } - - input := replicate.PredictionInput{ - "prompt": PromptPrefix + prompt, - "max_new_tokens": 4096, - } - webhook := replicate.Webhook{ - URL: "https://example.com/webhook", - Events: []replicate.WebhookEventType{"start", "completed"}, - } - - prediction, predictionError := client.Run(context.Background(), "meta/codellama-70b-instruct:a279116fe47a0f65701a8817188601e2fe8f4b9e04a518789655ea7b995851bf", input, &webhook) - - if predictionError != nil { - return "", predictionError - } - - if prediction == nil { - return "", errors.New("there was an error generating a response based on this prompt... please reach out to @himbothyswaggins to fix this issue") - } - - test, ok := prediction.([]interface{}) - - if !ok { - return "", errors.New("there was an error generating a response based on this prompt... please reach out to @himbothyswaggins to fix this issue") - } - - strs := make([]string, len(test)) - for i, v := range test { - str, ok := v.(string) - if !ok { - return "", errors.New("element is not a string") - } - strs[i] = str - } - - result := strings.Join(strs, "") - - return result, nil -} - -func ReplicateImageGeneration(prompt string, filename string) (*bytes.Buffer, error) { - client, clientError := replicate.NewClient(replicate.WithTokenFromEnv()) - if clientError != nil { - return nil, clientError - } - - input := replicate.PredictionInput{ - "width": 1024, - "height": 1024, - "prompt": prompt, - "scheduler": "K_EULER", - "num_outputs": 1, - "guidance_scale": 0, - "negative_prompt": "worst quality, low quality", - "num_inference_steps": 4, - "disable_safety_checker": true, - } - webhook := replicate.Webhook{ - URL: "https://example.com/webhook", - Events: []replicate.WebhookEventType{"start", "completed"}, - } - - prediction, predictionError := client.Run(context.Background(), "lucataco/sdxl-lightning-4step:727e49a643e999d602a896c774a0658ffefea21465756a6ce24b7ea4165eba6a", input, &webhook) - - if predictionError != nil { - return nil, predictionError - } - - if prediction == nil { - return nil, errors.New("there was an error generating the image based on this prompt... please reach out to @himbothyswaggins to fix this issue") - } - - test, ok := prediction.([]interface{}) - - if !ok { - return nil, errors.New("there was an error generating the image based on this prompt... please reach out to @himbothyswaggins to fix this issue") - } - - imgUrl, ok := test[0].(string) - - if !ok { - return nil, errors.New("there was an error generating the image based on this prompt... please reach out to @himbothyswaggins to fix this issue") - } - - imageRes, imageGetErr := http.Get(imgUrl) - if imageGetErr != nil { - return nil, imageGetErr - } - - defer imageRes.Body.Close() - - imageBytes, imgReadErr := io.ReadAll(imageRes.Body) - if imgReadErr != nil { - return nil, imgReadErr - } - - // Save image to a temporary file - tmpfile, err := os.Create(filename) - - if err != nil { - log.Fatal(err) - } - - defer os.Remove(tmpfile.Name()) - - if _, err := tmpfile.Write(imageBytes); err != nil { - log.Fatal(err) - } - if err := tmpfile.Close(); err != nil { - log.Fatal(err) - } - - // Upload the image to S3 - _, uploadErr := UploadToS3(tmpfile.Name()) - if uploadErr != nil { - log.Printf("Failed to upload image to S3: %v", uploadErr) - } - - imageFile := bytes.NewBuffer(imageBytes) - return imageFile, nil -} diff --git a/main.go b/main.go index 0fe451e..c44f77a 100644 --- a/main.go +++ b/main.go @@ -31,17 +31,6 @@ var commands = []api.CreateCommandData{ }, }, }, - { - Name: "code", - Description: "Ask Himbot programming questions! Cooldown: 2 Minutes.", - Options: []discord.CommandOption{ - &discord.StringOption{ - OptionName: "prompt", - Description: "The prompt to send to Himbot.", - Required: true, - }, - }, - }, { Name: "pic", Description: "Generate an image! Cooldown: 1 Minute.", @@ -107,7 +96,6 @@ func newHandler(s *state.State) *handler { h.Use(cmdroute.Deferrable(s, cmdroute.DeferOpts{})) h.AddFunc("ping", command.Ping) h.AddFunc("ask", command.Ask) - h.AddFunc("code", command.Code) h.AddFunc("pic", command.Pic) h.AddFunc("hs", command.HS)