Cleaned up more business logic into the lib package

This commit is contained in:
Atridad Lahiji 2024-01-09 21:34:47 -07:00
parent 4f6b7894a9
commit 76300264e6
No known key found for this signature in database
5 changed files with 171 additions and 96 deletions

2
go.mod
View file

@ -5,11 +5,13 @@ go 1.21.6
require github.com/diamondburned/arikawa/v3 v3.3.4
require (
github.com/pkg/errors v0.9.1 // indirect
golang.org/x/net v0.20.0 // indirect
golang.org/x/sync v0.6.0 // indirect
)
require (
github.com/diamondburned/arikawa v1.3.14
github.com/gorilla/schema v1.2.1 // indirect
github.com/gorilla/websocket v1.5.1 // indirect
github.com/joho/godotenv v1.5.1

9
go.sum
View file

@ -1,7 +1,10 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/diamondburned/arikawa v1.3.14 h1:9Y1r8nlvWA01hQIYVFGr83JTLd5ULPBzNOKvxbAsH1M=
github.com/diamondburned/arikawa v1.3.14/go.mod h1:nIhVIatzTQhPUa7NB8w4koG1RF9gYbpAr8Fj8sKq660=
github.com/diamondburned/arikawa/v3 v3.3.4 h1:UXOjM7PRlWLJ8kVAydX/VetqV7W4/d4xU92JRy3SpU4=
github.com/diamondburned/arikawa/v3 v3.3.4/go.mod h1:5KMSeB9R2Kzi6K4EcqMz7mwAFpAi5jglX/Veq0+MPOo=
github.com/gorilla/schema v1.1.0/go.mod h1:kgLaKoK1FELgZqMAVxx/5cbj0kT+57qxUrAlIO2eleU=
github.com/gorilla/schema v1.2.0/go.mod h1:kgLaKoK1FELgZqMAVxx/5cbj0kT+57qxUrAlIO2eleU=
github.com/gorilla/schema v1.2.1 h1:tjDxcmdb+siIqkTNoV+qRH2mjYdr2hHe5MKXbp61ziM=
github.com/gorilla/schema v1.2.1/go.mod h1:Dg5SSm5PV60mhF2NFaTV1xuYYj8tV8NOPRo4FggUMnM=
@ -10,6 +13,8 @@ github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/
github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/replicate/replicate-go v0.14.2 h1:XgK+REvYrWs7qDeyugxHA93h31qBhEFk/3p1/p2w3W8=
@ -22,9 +27,11 @@ github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcU
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200423211502-4bdfaf469ed5/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
@ -40,6 +47,7 @@ golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@ -52,6 +60,7 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=

78
lib/openai.go Normal file
View file

@ -0,0 +1,78 @@
package lib
import (
"bytes"
"context"
"fmt"
"io"
"log"
"net/http"
"os"
"github.com/joho/godotenv"
"github.com/sashabaranov/go-openai"
)
var client *openai.Client
func init() {
godotenv.Load(".env")
apiKey := os.Getenv("OPENAI_API_KEY")
if apiKey == "" {
fmt.Println("OPENAI_API_KEY environment variable not set")
os.Exit(1)
}
client = openai.NewClient(apiKey)
}
func OpenAITextGeneration(prompt string) (string, error) {
resp, err := client.CreateChatCompletion(
context.Background(),
openai.ChatCompletionRequest{
Model: openai.GPT4TurboPreview,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: prompt,
},
},
},
)
if err != nil {
fmt.Printf("ChatCompletion error: %v\n", err)
return "", err
}
return resp.Choices[0].Message.Content, nil
}
func OpenAIImageGeneration(prompt string) (*bytes.Buffer, error) {
// Send the generation request to DALL·E 3
resp, err := client.CreateImage(context.Background(), openai.ImageRequest{
Prompt: prompt,
Model: "dall-e-3",
Size: "1024x1024",
})
if err != nil {
log.Printf("Image creation error: %v\n", err)
return nil, fmt.Errorf("failed to generate image")
}
imageRes, err := http.Get(resp.Data[0].URL)
if err != nil {
return nil, err
}
defer imageRes.Body.Close()
imageBytes, err := io.ReadAll(imageRes.Body)
if err != nil {
return nil, err
}
imageFile := bytes.NewBuffer(imageBytes)
return imageFile, nil
}

59
lib/replicate.go Normal file
View file

@ -0,0 +1,59 @@
package lib
import (
"bytes"
"context"
"errors"
"io"
"net/http"
"github.com/replicate/replicate-go"
)
func ReplicateImageGeneration(prompt string) (*bytes.Buffer, error) {
client, clientError := replicate.NewClient(replicate.WithTokenFromEnv())
if clientError != nil {
return nil, clientError
}
input := replicate.PredictionInput{
"prompt": prompt,
}
webhook := replicate.Webhook{
URL: "https://example.com/webhook",
Events: []replicate.WebhookEventType{"start", "completed"},
}
prediction, predictionError := client.Run(context.Background(), "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", input, &webhook)
if predictionError != nil {
return nil, predictionError
}
test, ok := prediction.([]interface{})
if !ok {
return nil, errors.New("there was an error generating the image based on this prompt... this usually happens when the generated image violates safety requirements")
}
imgUrl, ok := test[0].(string)
if !ok {
return nil, errors.New("there was an error generating the image based on this prompt... this usually happens when the generated image violates safety requirements")
}
imageRes, imageGetErr := http.Get(imgUrl)
if imageGetErr != nil {
return nil, imageGetErr
}
defer imageRes.Body.Close()
imageBytes, imgReadErr := io.ReadAll(imageRes.Body)
if imgReadErr != nil {
return nil, imgReadErr
}
imageFile := bytes.NewBuffer(imageBytes)
return imageFile, nil
}

119
main.go
View file

@ -6,10 +6,8 @@ import (
"errors"
"fmt"
"himbot/lib"
"io"
"log"
"net"
"net/http"
"os"
"os/signal"
"time"
@ -22,11 +20,8 @@ import (
"github.com/diamondburned/arikawa/v3/utils/json/option"
"github.com/diamondburned/arikawa/v3/utils/sendpart"
"github.com/joho/godotenv"
"github.com/replicate/replicate-go"
openai "github.com/sashabaranov/go-openai"
)
// Command metadata
var commands = []api.CreateCommandData{
{
Name: "ping",
@ -78,7 +73,6 @@ var commands = []api.CreateCommandData{
},
}
// Entrypoint
func main() {
godotenv.Load(".env")
@ -151,35 +145,33 @@ func (h *handler) cmdAsk(ctx context.Context, data cmdroute.CommandData) *api.In
return errorResponse(err)
}
apiKey := os.Getenv("OPENAI_API_KEY")
client := openai.NewClient(apiKey)
resp, err := client.CreateChatCompletion(
context.Background(),
openai.ChatCompletionRequest{
Model: openai.GPT4TurboPreview,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: options.Prompt,
},
},
},
)
respString, err := lib.OpenAITextGeneration(options.Prompt)
if err != nil {
return errorResponse(err)
fmt.Printf("ChatCompletion error: %v\n", err)
return &api.InteractionResponseData{
Content: option.NewNullableString("ChatCompletion Error!"),
AllowedMentions: &api.AllowedMentions{}, // don't mention anyone
}
}
respString := resp.Choices[0].Message.Content
if len(respString) > 1800 {
respString = respString[:1800]
}
textFile := bytes.NewBuffer([]byte(respString))
file := sendpart.File{
Name: "himbot_response.txt",
Reader: textFile,
}
return &api.InteractionResponseData{
Content: option.NewNullableString("Prompt: " + options.Prompt + "\n" + "Response:\n"),
AllowedMentions: &api.AllowedMentions{}, // don't mention anyone
Files: []sendpart.File{file},
}
}
return &api.InteractionResponseData{
Content: option.NewNullableString("Prompt: " + options.Prompt + "\n" + "Response: " + respString),
AllowedMentions: &api.AllowedMentions{},
AllowedMentions: &api.AllowedMentions{}, // don't mention anyone
}
}
@ -200,54 +192,12 @@ func (h *handler) cmdPic(ctx context.Context, data cmdroute.CommandData) *api.In
return errorResponse(err)
}
client, clientError := replicate.NewClient(replicate.WithTokenFromEnv())
if clientError != nil {
return errorResponse(clientError)
}
if err := data.Options.Unmarshal(&options); err != nil {
imageFile, err := lib.ReplicateImageGeneration(options.Prompt)
if err != nil {
return errorResponse(err)
}
input := replicate.PredictionInput{
"prompt": options.Prompt,
}
webhook := replicate.Webhook{
URL: "https://example.com/webhook",
Events: []replicate.WebhookEventType{"start", "completed"},
}
prediction, predictionError := client.Run(context.Background(), "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", input, &webhook)
if predictionError != nil {
return errorResponse(predictionError)
}
test, ok := prediction.([]interface{})
if !ok {
return errorResponse(errors.New("there was an error generating the image based on this prompt... this usually happens when the generated image violates safety requirements"))
}
imgUrl, ok := test[0].(string)
if !ok {
return errorResponse(errors.New("there was an error generating the image based on this prompt... this usually happens when the generated image violates safety requirements"))
}
imageRes, imageGetErr := http.Get(imgUrl)
if imageGetErr != nil {
return errorResponse(imageGetErr)
}
defer imageRes.Body.Close()
imageBytes, imgReadErr := io.ReadAll(imageRes.Body)
if imgReadErr != nil {
return errorResponse(imgReadErr)
}
imageFile := bytes.NewBuffer(imageBytes)
file := sendpart.File{
Name: "himbot_response.png",
Reader: imageFile,
@ -276,35 +226,12 @@ func (h *handler) cmdHDPic(ctx context.Context, data cmdroute.CommandData) *api.
return errorResponse(err)
}
client := openai.NewClient(os.Getenv("OPENAI_API_KEY"))
// Send the generation request to DALL·E 3
resp, err := client.CreateImage(context.Background(), openai.ImageRequest{
Prompt: options.Prompt,
Model: "dall-e-3",
Size: "1024x1024",
})
if err != nil {
log.Printf("Image creation error: %v\n", err)
return errorResponse(fmt.Errorf("failed to generate image"))
}
imageRes, err := http.Get(resp.Data[0].URL)
imageFile, err := lib.OpenAIImageGeneration(options.Prompt)
if err != nil {
return errorResponse(err)
}
defer imageRes.Body.Close()
imageBytes, err := io.ReadAll(imageRes.Body)
if err != nil {
return errorResponse(err)
}
imageFile := bytes.NewBuffer(imageBytes)
file := sendpart.File{
Name: "himbot_response.png",
Reader: imageFile,