Back to OpenAI for cost LOL

This commit is contained in:
Atridad Lahiji 2024-02-23 13:52:15 -07:00
parent 02686b9087
commit a9f2f42987
No known key found for this signature in database
8 changed files with 118 additions and 280 deletions

View file

@ -32,7 +32,7 @@ func Ask(ctx context.Context, data cmdroute.CommandData) *api.InteractionRespons
return lib.ErrorResponse(err)
}
respString, err := lib.ReplicateTextGeneration(options.Prompt)
respString, err := lib.OpenAITextGeneration(options.Prompt)
if err != nil {
fmt.Printf("ChatCompletion error: %v\n", err)

View file

@ -1,64 +0,0 @@
package command
import (
"bytes"
"context"
"errors"
"fmt"
"himbot/lib"
"time"
"github.com/diamondburned/arikawa/v3/api"
"github.com/diamondburned/arikawa/v3/api/cmdroute"
"github.com/diamondburned/arikawa/v3/utils/json/option"
"github.com/diamondburned/arikawa/v3/utils/sendpart"
)
func Code(ctx context.Context, data cmdroute.CommandData) *api.InteractionResponseData {
// Cooldown Logic
allowed, cooldownString := lib.CooldownHandler(*data.Event, "code", time.Minute)
if !allowed {
return lib.ErrorResponse(errors.New(cooldownString))
}
// Command Logic
var options struct {
Prompt string `discord:"prompt"`
}
if err := data.Options.Unmarshal(&options); err != nil {
lib.CancelCooldown(data.Event.User.ID.String(), "ask")
return lib.ErrorResponse(err)
}
respString, err := lib.ReplicateCodeGeneration(options.Prompt)
if err != nil {
fmt.Printf("ChatCompletion error: %v\n", err)
lib.CancelCooldown(data.Event.User.ID.String(), "ask")
return &api.InteractionResponseData{
Content: option.NewNullableString("ChatCompletion Error!"),
AllowedMentions: &api.AllowedMentions{},
}
}
if len(respString) > 1800 {
textFile := bytes.NewBuffer([]byte(respString))
file := sendpart.File{
Name: "himbot_response.md",
Reader: textFile,
}
return &api.InteractionResponseData{
Content: option.NewNullableString("Prompt: " + options.Prompt + "\n"),
AllowedMentions: &api.AllowedMentions{},
Files: []sendpart.File{file},
}
}
return &api.InteractionResponseData{
Content: option.NewNullableString("Prompt: " + options.Prompt + "\n--------------------\n" + respString),
AllowedMentions: &api.AllowedMentions{},
}
}

View file

@ -37,7 +37,7 @@ func Pic(ctx context.Context, data cmdroute.CommandData) *api.InteractionRespons
// Concatenate clean username and timestamp to form filename
filename := data.Event.Sender().Username + "_" + timestamp + ".jpg"
imageFile, err := lib.ReplicateImageGeneration(options.Prompt, filename)
imageFile, err := lib.OpenAIImageGeneration(options.Prompt, filename)
if err != nil {
lib.CancelCooldown(data.Event.User.ID.String(), "pic")

7
go.mod
View file

@ -5,16 +5,17 @@ go 1.22.0
require github.com/diamondburned/arikawa/v3 v3.3.5
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
golang.org/x/net v0.21.0 // indirect
golang.org/x/sync v0.6.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
)
require (
github.com/aws/aws-sdk-go v1.50.24
github.com/aws/aws-sdk-go v1.50.25
github.com/gorilla/schema v1.2.1 // indirect
github.com/gorilla/websocket v1.5.1 // indirect
github.com/joho/godotenv v1.5.1
github.com/replicate/replicate-go v0.16.1
github.com/sashabaranov/go-openai v1.20.0
golang.org/x/time v0.5.0 // indirect
)

14
go.sum
View file

@ -1,5 +1,5 @@
github.com/aws/aws-sdk-go v1.50.24 h1:3o2Pg7mOoVL0jv54vWtuafoZqAeEXLhm1tltWA2GcEw=
github.com/aws/aws-sdk-go v1.50.24/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk=
github.com/aws/aws-sdk-go v1.50.25 h1:vhiHtLYybv1Nhx3Kv18BBC6L0aPJHaG9aeEsr92W99c=
github.com/aws/aws-sdk-go v1.50.25/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@ -19,11 +19,9 @@ github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/replicate/replicate-go v0.16.1 h1:LbImDfB6ef0yEfWbBNQdnC5CoKmHxonoa/UUJ6YrFC8=
github.com/replicate/replicate-go v0.16.1/go.mod h1:otIrl1vDmyjNhTzmVmp/mQU3Wt1+3387gFNEsAZq0ig=
github.com/sashabaranov/go-openai v1.20.0 h1:r9WiwJY6Q2aPDhVyfOSKm83Gs04ogN1yaaBoQOnusS4=
github.com/sashabaranov/go-openai v1.20.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
@ -40,8 +38,6 @@ golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@ -74,5 +70,3 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

108
lib/openai.go Normal file
View file

@ -0,0 +1,108 @@
package lib
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"log"
"net/http"
"os"
"github.com/joho/godotenv"
"github.com/sashabaranov/go-openai"
)
var PromptPrefix = "Your name is Himbot. You are a helpful but sarcastic and witty discord bot. Please respond with a natural response to the following prompt with that personality in mind:"
func OpenAITextGeneration(prompt string) (string, error) {
godotenv.Load(".env")
apiKey := os.Getenv("OPENAI_API_KEY")
client := openai.NewClient(apiKey)
resp, err := client.CreateChatCompletion(
context.Background(),
openai.ChatCompletionRequest{
Model: openai.GPT4Turbo1106,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: PromptPrefix + prompt,
},
},
},
)
if err != nil {
fmt.Printf("Ask command error: %v\n", err)
return "", errors.New("there was an error generating the response based on this prompt... please reach out to @himbothyswaggins to fix this issue")
}
return resp.Choices[0].Message.Content, nil
}
func OpenAIImageGeneration(prompt string, filename string) (*bytes.Buffer, error) {
godotenv.Load(".env")
apiKey := os.Getenv("OPENAI_API_KEY")
client := openai.NewClient(apiKey)
imageResponse, err := client.CreateImage(
context.Background(),
openai.ImageRequest{
Model: openai.CreateImageModelDallE3,
Prompt: prompt,
Size: openai.CreateImageSize1024x1024,
Quality: openai.CreateImageQualityStandard,
ResponseFormat: openai.CreateImageResponseFormatURL,
N: 1,
},
)
if err != nil {
fmt.Printf("Pic command error: %v\n", err)
return nil, errors.New("there was an error generating the image based on this prompt... please reach out to @himbothyswaggins to fix this issue")
}
imgUrl := imageResponse.Data[0].URL
imageRes, imageGetErr := http.Get(imgUrl)
if imageGetErr != nil {
return nil, imageGetErr
}
defer imageRes.Body.Close()
imageBytes, imgReadErr := io.ReadAll(imageRes.Body)
if imgReadErr != nil {
return nil, imgReadErr
}
// Save image to a temporary file
tmpfile, err := os.Create(filename)
if err != nil {
log.Fatal(err)
}
defer os.Remove(tmpfile.Name())
if _, err := tmpfile.Write(imageBytes); err != nil {
log.Fatal(err)
}
if err := tmpfile.Close(); err != nil {
log.Fatal(err)
}
// Upload the image to S3
_, uploadErr := UploadToS3(tmpfile.Name())
if uploadErr != nil {
log.Printf("Failed to upload image to S3: %v", uploadErr)
}
imageFile := bytes.NewBuffer(imageBytes)
return imageFile, nil
}

View file

@ -1,189 +0,0 @@
package lib
import (
"bytes"
"context"
"errors"
"io"
"log"
"net/http"
"os"
"strings"
"github.com/replicate/replicate-go"
)
var PromptPrefix = "Ready for a dose of sarcasm and wit? Himbot, your Discord assistant, is up for the challenge. Hit it with the prompt:"
func ReplicateTextGeneration(prompt string) (string, error) {
client, clientError := replicate.NewClient(replicate.WithTokenFromEnv())
if clientError != nil {
return "", clientError
}
input := replicate.PredictionInput{
"prompt": PromptPrefix + prompt,
"max_new_tokens": 4096,
"prompt_template": "<s>[INST] {prompt} [/INST]",
}
webhook := replicate.Webhook{
URL: "https://example.com/webhook",
Events: []replicate.WebhookEventType{"start", "completed"},
}
prediction, predictionError := client.Run(context.Background(), "mistralai/mixtral-8x7b-instruct-v0.1:cf18decbf51c27fed6bbdc3492312c1c903222a56e3fe9ca02d6cbe5198afc10", input, &webhook)
if predictionError != nil {
return "", predictionError
}
if prediction == nil {
return "", errors.New("there was an error generating a response based on this prompt... please reach out to @himbothyswaggins to fix this issue")
}
test, ok := prediction.([]interface{})
if !ok {
return "", errors.New("there was an error generating a response based on this prompt... please reach out to @himbothyswaggins to fix this issue")
}
strs := make([]string, len(test))
for i, v := range test {
str, ok := v.(string)
if !ok {
return "", errors.New("element is not a string")
}
strs[i] = str
}
result := strings.Join(strs, "")
return result, nil
}
func ReplicateCodeGeneration(prompt string) (string, error) {
client, clientError := replicate.NewClient(replicate.WithTokenFromEnv())
if clientError != nil {
return "", clientError
}
input := replicate.PredictionInput{
"prompt": PromptPrefix + prompt,
"max_new_tokens": 4096,
}
webhook := replicate.Webhook{
URL: "https://example.com/webhook",
Events: []replicate.WebhookEventType{"start", "completed"},
}
prediction, predictionError := client.Run(context.Background(), "meta/codellama-70b-instruct:a279116fe47a0f65701a8817188601e2fe8f4b9e04a518789655ea7b995851bf", input, &webhook)
if predictionError != nil {
return "", predictionError
}
if prediction == nil {
return "", errors.New("there was an error generating a response based on this prompt... please reach out to @himbothyswaggins to fix this issue")
}
test, ok := prediction.([]interface{})
if !ok {
return "", errors.New("there was an error generating a response based on this prompt... please reach out to @himbothyswaggins to fix this issue")
}
strs := make([]string, len(test))
for i, v := range test {
str, ok := v.(string)
if !ok {
return "", errors.New("element is not a string")
}
strs[i] = str
}
result := strings.Join(strs, "")
return result, nil
}
func ReplicateImageGeneration(prompt string, filename string) (*bytes.Buffer, error) {
client, clientError := replicate.NewClient(replicate.WithTokenFromEnv())
if clientError != nil {
return nil, clientError
}
input := replicate.PredictionInput{
"width": 1024,
"height": 1024,
"prompt": prompt,
"scheduler": "K_EULER",
"num_outputs": 1,
"guidance_scale": 0,
"negative_prompt": "worst quality, low quality",
"num_inference_steps": 4,
"disable_safety_checker": true,
}
webhook := replicate.Webhook{
URL: "https://example.com/webhook",
Events: []replicate.WebhookEventType{"start", "completed"},
}
prediction, predictionError := client.Run(context.Background(), "lucataco/sdxl-lightning-4step:727e49a643e999d602a896c774a0658ffefea21465756a6ce24b7ea4165eba6a", input, &webhook)
if predictionError != nil {
return nil, predictionError
}
if prediction == nil {
return nil, errors.New("there was an error generating the image based on this prompt... please reach out to @himbothyswaggins to fix this issue")
}
test, ok := prediction.([]interface{})
if !ok {
return nil, errors.New("there was an error generating the image based on this prompt... please reach out to @himbothyswaggins to fix this issue")
}
imgUrl, ok := test[0].(string)
if !ok {
return nil, errors.New("there was an error generating the image based on this prompt... please reach out to @himbothyswaggins to fix this issue")
}
imageRes, imageGetErr := http.Get(imgUrl)
if imageGetErr != nil {
return nil, imageGetErr
}
defer imageRes.Body.Close()
imageBytes, imgReadErr := io.ReadAll(imageRes.Body)
if imgReadErr != nil {
return nil, imgReadErr
}
// Save image to a temporary file
tmpfile, err := os.Create(filename)
if err != nil {
log.Fatal(err)
}
defer os.Remove(tmpfile.Name())
if _, err := tmpfile.Write(imageBytes); err != nil {
log.Fatal(err)
}
if err := tmpfile.Close(); err != nil {
log.Fatal(err)
}
// Upload the image to S3
_, uploadErr := UploadToS3(tmpfile.Name())
if uploadErr != nil {
log.Printf("Failed to upload image to S3: %v", uploadErr)
}
imageFile := bytes.NewBuffer(imageBytes)
return imageFile, nil
}

12
main.go
View file

@ -31,17 +31,6 @@ var commands = []api.CreateCommandData{
},
},
},
{
Name: "code",
Description: "Ask Himbot programming questions! Cooldown: 2 Minutes.",
Options: []discord.CommandOption{
&discord.StringOption{
OptionName: "prompt",
Description: "The prompt to send to Himbot.",
Required: true,
},
},
},
{
Name: "pic",
Description: "Generate an image! Cooldown: 1 Minute.",
@ -107,7 +96,6 @@ func newHandler(s *state.State) *handler {
h.Use(cmdroute.Deferrable(s, cmdroute.DeferOpts{}))
h.AddFunc("ping", command.Ping)
h.AddFunc("ask", command.Ask)
h.AddFunc("code", command.Code)
h.AddFunc("pic", command.Pic)
h.AddFunc("hs", command.HS)