From 76300264e6d8eef7a89bede7f2feb33b78c67acc Mon Sep 17 00:00:00 2001 From: atridadl Date: Tue, 9 Jan 2024 21:34:47 -0700 Subject: [PATCH] Cleaned up more business logic into the lib package --- go.mod | 2 + go.sum | 9 ++++ lib/openai.go | 78 +++++++++++++++++++++++++++++++ lib/replicate.go | 59 +++++++++++++++++++++++ main.go | 119 +++++++++-------------------------------------- 5 files changed, 171 insertions(+), 96 deletions(-) create mode 100644 lib/openai.go create mode 100644 lib/replicate.go diff --git a/go.mod b/go.mod index db78ec4..6daf861 100644 --- a/go.mod +++ b/go.mod @@ -5,11 +5,13 @@ go 1.21.6 require github.com/diamondburned/arikawa/v3 v3.3.4 require ( + github.com/pkg/errors v0.9.1 // indirect golang.org/x/net v0.20.0 // indirect golang.org/x/sync v0.6.0 // indirect ) require ( + github.com/diamondburned/arikawa v1.3.14 github.com/gorilla/schema v1.2.1 // indirect github.com/gorilla/websocket v1.5.1 // indirect github.com/joho/godotenv v1.5.1 diff --git a/go.sum b/go.sum index 33ace0a..2c56564 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,10 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/diamondburned/arikawa v1.3.14 h1:9Y1r8nlvWA01hQIYVFGr83JTLd5ULPBzNOKvxbAsH1M= +github.com/diamondburned/arikawa v1.3.14/go.mod h1:nIhVIatzTQhPUa7NB8w4koG1RF9gYbpAr8Fj8sKq660= github.com/diamondburned/arikawa/v3 v3.3.4 h1:UXOjM7PRlWLJ8kVAydX/VetqV7W4/d4xU92JRy3SpU4= github.com/diamondburned/arikawa/v3 v3.3.4/go.mod h1:5KMSeB9R2Kzi6K4EcqMz7mwAFpAi5jglX/Veq0+MPOo= +github.com/gorilla/schema v1.1.0/go.mod h1:kgLaKoK1FELgZqMAVxx/5cbj0kT+57qxUrAlIO2eleU= github.com/gorilla/schema v1.2.0/go.mod h1:kgLaKoK1FELgZqMAVxx/5cbj0kT+57qxUrAlIO2eleU= github.com/gorilla/schema v1.2.1 h1:tjDxcmdb+siIqkTNoV+qRH2mjYdr2hHe5MKXbp61ziM= github.com/gorilla/schema v1.2.1/go.mod h1:Dg5SSm5PV60mhF2NFaTV1xuYYj8tV8NOPRo4FggUMnM= @@ -10,6 +13,8 @@ github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/ github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/replicate/replicate-go v0.14.2 h1:XgK+REvYrWs7qDeyugxHA93h31qBhEFk/3p1/p2w3W8= @@ -22,9 +27,11 @@ github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcU github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200423211502-4bdfaf469ed5/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= @@ -40,6 +47,7 @@ golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -52,6 +60,7 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= diff --git a/lib/openai.go b/lib/openai.go new file mode 100644 index 0000000..0c20d3d --- /dev/null +++ b/lib/openai.go @@ -0,0 +1,78 @@ +package lib + +import ( + "bytes" + "context" + "fmt" + "io" + "log" + "net/http" + "os" + + "github.com/joho/godotenv" + "github.com/sashabaranov/go-openai" +) + +var client *openai.Client + +func init() { + godotenv.Load(".env") + apiKey := os.Getenv("OPENAI_API_KEY") + if apiKey == "" { + fmt.Println("OPENAI_API_KEY environment variable not set") + os.Exit(1) + } + client = openai.NewClient(apiKey) +} + +func OpenAITextGeneration(prompt string) (string, error) { + resp, err := client.CreateChatCompletion( + context.Background(), + openai.ChatCompletionRequest{ + Model: openai.GPT4TurboPreview, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: prompt, + }, + }, + }, + ) + + if err != nil { + fmt.Printf("ChatCompletion error: %v\n", err) + return "", err + } + + return resp.Choices[0].Message.Content, nil +} + +func OpenAIImageGeneration(prompt string) (*bytes.Buffer, error) { + // Send the generation request to DALL·E 3 + resp, err := client.CreateImage(context.Background(), openai.ImageRequest{ + Prompt: prompt, + Model: "dall-e-3", + Size: "1024x1024", + }) + if err != nil { + log.Printf("Image creation error: %v\n", err) + return nil, fmt.Errorf("failed to generate image") + } + + imageRes, err := http.Get(resp.Data[0].URL) + + if err != nil { + return nil, err + } + + defer imageRes.Body.Close() + + imageBytes, err := io.ReadAll(imageRes.Body) + + if err != nil { + return nil, err + } + + imageFile := bytes.NewBuffer(imageBytes) + return imageFile, nil +} diff --git a/lib/replicate.go b/lib/replicate.go new file mode 100644 index 0000000..bf7f33c --- /dev/null +++ b/lib/replicate.go @@ -0,0 +1,59 @@ +package lib + +import ( + "bytes" + "context" + "errors" + "io" + "net/http" + + "github.com/replicate/replicate-go" +) + +func ReplicateImageGeneration(prompt string) (*bytes.Buffer, error) { + client, clientError := replicate.NewClient(replicate.WithTokenFromEnv()) + if clientError != nil { + return nil, clientError + } + + input := replicate.PredictionInput{ + "prompt": prompt, + } + webhook := replicate.Webhook{ + URL: "https://example.com/webhook", + Events: []replicate.WebhookEventType{"start", "completed"}, + } + + prediction, predictionError := client.Run(context.Background(), "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", input, &webhook) + + if predictionError != nil { + return nil, predictionError + } + + test, ok := prediction.([]interface{}) + + if !ok { + return nil, errors.New("there was an error generating the image based on this prompt... this usually happens when the generated image violates safety requirements") + } + + imgUrl, ok := test[0].(string) + + if !ok { + return nil, errors.New("there was an error generating the image based on this prompt... this usually happens when the generated image violates safety requirements") + } + + imageRes, imageGetErr := http.Get(imgUrl) + if imageGetErr != nil { + return nil, imageGetErr + } + + defer imageRes.Body.Close() + + imageBytes, imgReadErr := io.ReadAll(imageRes.Body) + if imgReadErr != nil { + return nil, imgReadErr + } + + imageFile := bytes.NewBuffer(imageBytes) + return imageFile, nil +} diff --git a/main.go b/main.go index 2843859..1e1c16c 100644 --- a/main.go +++ b/main.go @@ -6,10 +6,8 @@ import ( "errors" "fmt" "himbot/lib" - "io" "log" "net" - "net/http" "os" "os/signal" "time" @@ -22,11 +20,8 @@ import ( "github.com/diamondburned/arikawa/v3/utils/json/option" "github.com/diamondburned/arikawa/v3/utils/sendpart" "github.com/joho/godotenv" - "github.com/replicate/replicate-go" - openai "github.com/sashabaranov/go-openai" ) -// Command metadata var commands = []api.CreateCommandData{ { Name: "ping", @@ -78,7 +73,6 @@ var commands = []api.CreateCommandData{ }, } -// Entrypoint func main() { godotenv.Load(".env") @@ -151,35 +145,33 @@ func (h *handler) cmdAsk(ctx context.Context, data cmdroute.CommandData) *api.In return errorResponse(err) } - apiKey := os.Getenv("OPENAI_API_KEY") - client := openai.NewClient(apiKey) - - resp, err := client.CreateChatCompletion( - context.Background(), - openai.ChatCompletionRequest{ - Model: openai.GPT4TurboPreview, - Messages: []openai.ChatCompletionMessage{ - { - Role: openai.ChatMessageRoleUser, - Content: options.Prompt, - }, - }, - }, - ) + respString, err := lib.OpenAITextGeneration(options.Prompt) if err != nil { - return errorResponse(err) + fmt.Printf("ChatCompletion error: %v\n", err) + return &api.InteractionResponseData{ + Content: option.NewNullableString("ChatCompletion Error!"), + AllowedMentions: &api.AllowedMentions{}, // don't mention anyone + } } - respString := resp.Choices[0].Message.Content - if len(respString) > 1800 { - respString = respString[:1800] - } + textFile := bytes.NewBuffer([]byte(respString)) + file := sendpart.File{ + Name: "himbot_response.txt", + Reader: textFile, + } + + return &api.InteractionResponseData{ + Content: option.NewNullableString("Prompt: " + options.Prompt + "\n" + "Response:\n"), + AllowedMentions: &api.AllowedMentions{}, // don't mention anyone + Files: []sendpart.File{file}, + } + } return &api.InteractionResponseData{ Content: option.NewNullableString("Prompt: " + options.Prompt + "\n" + "Response: " + respString), - AllowedMentions: &api.AllowedMentions{}, + AllowedMentions: &api.AllowedMentions{}, // don't mention anyone } } @@ -200,54 +192,12 @@ func (h *handler) cmdPic(ctx context.Context, data cmdroute.CommandData) *api.In return errorResponse(err) } - client, clientError := replicate.NewClient(replicate.WithTokenFromEnv()) - if clientError != nil { - return errorResponse(clientError) - } - if err := data.Options.Unmarshal(&options); err != nil { + imageFile, err := lib.ReplicateImageGeneration(options.Prompt) + + if err != nil { return errorResponse(err) } - input := replicate.PredictionInput{ - "prompt": options.Prompt, - } - webhook := replicate.Webhook{ - URL: "https://example.com/webhook", - Events: []replicate.WebhookEventType{"start", "completed"}, - } - - prediction, predictionError := client.Run(context.Background(), "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", input, &webhook) - - if predictionError != nil { - return errorResponse(predictionError) - } - - test, ok := prediction.([]interface{}) - - if !ok { - return errorResponse(errors.New("there was an error generating the image based on this prompt... this usually happens when the generated image violates safety requirements")) - } - - imgUrl, ok := test[0].(string) - - if !ok { - return errorResponse(errors.New("there was an error generating the image based on this prompt... this usually happens when the generated image violates safety requirements")) - } - - imageRes, imageGetErr := http.Get(imgUrl) - if imageGetErr != nil { - return errorResponse(imageGetErr) - } - - defer imageRes.Body.Close() - - imageBytes, imgReadErr := io.ReadAll(imageRes.Body) - if imgReadErr != nil { - return errorResponse(imgReadErr) - } - - imageFile := bytes.NewBuffer(imageBytes) - file := sendpart.File{ Name: "himbot_response.png", Reader: imageFile, @@ -276,35 +226,12 @@ func (h *handler) cmdHDPic(ctx context.Context, data cmdroute.CommandData) *api. return errorResponse(err) } - client := openai.NewClient(os.Getenv("OPENAI_API_KEY")) - - // Send the generation request to DALL·E 3 - resp, err := client.CreateImage(context.Background(), openai.ImageRequest{ - Prompt: options.Prompt, - Model: "dall-e-3", - Size: "1024x1024", - }) - if err != nil { - log.Printf("Image creation error: %v\n", err) - return errorResponse(fmt.Errorf("failed to generate image")) - } - - imageRes, err := http.Get(resp.Data[0].URL) + imageFile, err := lib.OpenAIImageGeneration(options.Prompt) if err != nil { return errorResponse(err) } - defer imageRes.Body.Close() - - imageBytes, err := io.ReadAll(imageRes.Body) - - if err != nil { - return errorResponse(err) - } - - imageFile := bytes.NewBuffer(imageBytes) - file := sendpart.File{ Name: "himbot_response.png", Reader: imageFile,