Cleaned up more business logic into the lib package
This commit is contained in:
parent
4f6b7894a9
commit
76300264e6
5 changed files with 171 additions and 96 deletions
2
go.mod
2
go.mod
|
@ -5,11 +5,13 @@ go 1.21.6
|
|||
require github.com/diamondburned/arikawa/v3 v3.3.4
|
||||
|
||||
require (
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
golang.org/x/net v0.20.0 // indirect
|
||||
golang.org/x/sync v0.6.0 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/diamondburned/arikawa v1.3.14
|
||||
github.com/gorilla/schema v1.2.1 // indirect
|
||||
github.com/gorilla/websocket v1.5.1 // indirect
|
||||
github.com/joho/godotenv v1.5.1
|
||||
|
|
9
go.sum
9
go.sum
|
@ -1,7 +1,10 @@
|
|||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/diamondburned/arikawa v1.3.14 h1:9Y1r8nlvWA01hQIYVFGr83JTLd5ULPBzNOKvxbAsH1M=
|
||||
github.com/diamondburned/arikawa v1.3.14/go.mod h1:nIhVIatzTQhPUa7NB8w4koG1RF9gYbpAr8Fj8sKq660=
|
||||
github.com/diamondburned/arikawa/v3 v3.3.4 h1:UXOjM7PRlWLJ8kVAydX/VetqV7W4/d4xU92JRy3SpU4=
|
||||
github.com/diamondburned/arikawa/v3 v3.3.4/go.mod h1:5KMSeB9R2Kzi6K4EcqMz7mwAFpAi5jglX/Veq0+MPOo=
|
||||
github.com/gorilla/schema v1.1.0/go.mod h1:kgLaKoK1FELgZqMAVxx/5cbj0kT+57qxUrAlIO2eleU=
|
||||
github.com/gorilla/schema v1.2.0/go.mod h1:kgLaKoK1FELgZqMAVxx/5cbj0kT+57qxUrAlIO2eleU=
|
||||
github.com/gorilla/schema v1.2.1 h1:tjDxcmdb+siIqkTNoV+qRH2mjYdr2hHe5MKXbp61ziM=
|
||||
github.com/gorilla/schema v1.2.1/go.mod h1:Dg5SSm5PV60mhF2NFaTV1xuYYj8tV8NOPRo4FggUMnM=
|
||||
|
@ -10,6 +13,8 @@ github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/
|
|||
github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY=
|
||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/replicate/replicate-go v0.14.2 h1:XgK+REvYrWs7qDeyugxHA93h31qBhEFk/3p1/p2w3W8=
|
||||
|
@ -22,9 +27,11 @@ github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcU
|
|||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20200423211502-4bdfaf469ed5/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
|
@ -40,6 +47,7 @@ golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
|||
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
|
||||
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
|
@ -52,6 +60,7 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
|||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
|
||||
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
|
|
78
lib/openai.go
Normal file
78
lib/openai.go
Normal file
|
@ -0,0 +1,78 @@
|
|||
package lib
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
var client *openai.Client
|
||||
|
||||
func init() {
|
||||
godotenv.Load(".env")
|
||||
apiKey := os.Getenv("OPENAI_API_KEY")
|
||||
if apiKey == "" {
|
||||
fmt.Println("OPENAI_API_KEY environment variable not set")
|
||||
os.Exit(1)
|
||||
}
|
||||
client = openai.NewClient(apiKey)
|
||||
}
|
||||
|
||||
func OpenAITextGeneration(prompt string) (string, error) {
|
||||
resp, err := client.CreateChatCompletion(
|
||||
context.Background(),
|
||||
openai.ChatCompletionRequest{
|
||||
Model: openai.GPT4TurboPreview,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: prompt,
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
fmt.Printf("ChatCompletion error: %v\n", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
return resp.Choices[0].Message.Content, nil
|
||||
}
|
||||
|
||||
func OpenAIImageGeneration(prompt string) (*bytes.Buffer, error) {
|
||||
// Send the generation request to DALL·E 3
|
||||
resp, err := client.CreateImage(context.Background(), openai.ImageRequest{
|
||||
Prompt: prompt,
|
||||
Model: "dall-e-3",
|
||||
Size: "1024x1024",
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("Image creation error: %v\n", err)
|
||||
return nil, fmt.Errorf("failed to generate image")
|
||||
}
|
||||
|
||||
imageRes, err := http.Get(resp.Data[0].URL)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer imageRes.Body.Close()
|
||||
|
||||
imageBytes, err := io.ReadAll(imageRes.Body)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
imageFile := bytes.NewBuffer(imageBytes)
|
||||
return imageFile, nil
|
||||
}
|
59
lib/replicate.go
Normal file
59
lib/replicate.go
Normal file
|
@ -0,0 +1,59 @@
|
|||
package lib
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/replicate/replicate-go"
|
||||
)
|
||||
|
||||
func ReplicateImageGeneration(prompt string) (*bytes.Buffer, error) {
|
||||
client, clientError := replicate.NewClient(replicate.WithTokenFromEnv())
|
||||
if clientError != nil {
|
||||
return nil, clientError
|
||||
}
|
||||
|
||||
input := replicate.PredictionInput{
|
||||
"prompt": prompt,
|
||||
}
|
||||
webhook := replicate.Webhook{
|
||||
URL: "https://example.com/webhook",
|
||||
Events: []replicate.WebhookEventType{"start", "completed"},
|
||||
}
|
||||
|
||||
prediction, predictionError := client.Run(context.Background(), "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", input, &webhook)
|
||||
|
||||
if predictionError != nil {
|
||||
return nil, predictionError
|
||||
}
|
||||
|
||||
test, ok := prediction.([]interface{})
|
||||
|
||||
if !ok {
|
||||
return nil, errors.New("there was an error generating the image based on this prompt... this usually happens when the generated image violates safety requirements")
|
||||
}
|
||||
|
||||
imgUrl, ok := test[0].(string)
|
||||
|
||||
if !ok {
|
||||
return nil, errors.New("there was an error generating the image based on this prompt... this usually happens when the generated image violates safety requirements")
|
||||
}
|
||||
|
||||
imageRes, imageGetErr := http.Get(imgUrl)
|
||||
if imageGetErr != nil {
|
||||
return nil, imageGetErr
|
||||
}
|
||||
|
||||
defer imageRes.Body.Close()
|
||||
|
||||
imageBytes, imgReadErr := io.ReadAll(imageRes.Body)
|
||||
if imgReadErr != nil {
|
||||
return nil, imgReadErr
|
||||
}
|
||||
|
||||
imageFile := bytes.NewBuffer(imageBytes)
|
||||
return imageFile, nil
|
||||
}
|
119
main.go
119
main.go
|
@ -6,10 +6,8 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"himbot/lib"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"time"
|
||||
|
@ -22,11 +20,8 @@ import (
|
|||
"github.com/diamondburned/arikawa/v3/utils/json/option"
|
||||
"github.com/diamondburned/arikawa/v3/utils/sendpart"
|
||||
"github.com/joho/godotenv"
|
||||
"github.com/replicate/replicate-go"
|
||||
openai "github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
// Command metadata
|
||||
var commands = []api.CreateCommandData{
|
||||
{
|
||||
Name: "ping",
|
||||
|
@ -78,7 +73,6 @@ var commands = []api.CreateCommandData{
|
|||
},
|
||||
}
|
||||
|
||||
// Entrypoint
|
||||
func main() {
|
||||
godotenv.Load(".env")
|
||||
|
||||
|
@ -151,35 +145,33 @@ func (h *handler) cmdAsk(ctx context.Context, data cmdroute.CommandData) *api.In
|
|||
return errorResponse(err)
|
||||
}
|
||||
|
||||
apiKey := os.Getenv("OPENAI_API_KEY")
|
||||
client := openai.NewClient(apiKey)
|
||||
|
||||
resp, err := client.CreateChatCompletion(
|
||||
context.Background(),
|
||||
openai.ChatCompletionRequest{
|
||||
Model: openai.GPT4TurboPreview,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: options.Prompt,
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
respString, err := lib.OpenAITextGeneration(options.Prompt)
|
||||
|
||||
if err != nil {
|
||||
return errorResponse(err)
|
||||
fmt.Printf("ChatCompletion error: %v\n", err)
|
||||
return &api.InteractionResponseData{
|
||||
Content: option.NewNullableString("ChatCompletion Error!"),
|
||||
AllowedMentions: &api.AllowedMentions{}, // don't mention anyone
|
||||
}
|
||||
}
|
||||
|
||||
respString := resp.Choices[0].Message.Content
|
||||
|
||||
if len(respString) > 1800 {
|
||||
respString = respString[:1800]
|
||||
}
|
||||
textFile := bytes.NewBuffer([]byte(respString))
|
||||
|
||||
file := sendpart.File{
|
||||
Name: "himbot_response.txt",
|
||||
Reader: textFile,
|
||||
}
|
||||
|
||||
return &api.InteractionResponseData{
|
||||
Content: option.NewNullableString("Prompt: " + options.Prompt + "\n" + "Response:\n"),
|
||||
AllowedMentions: &api.AllowedMentions{}, // don't mention anyone
|
||||
Files: []sendpart.File{file},
|
||||
}
|
||||
}
|
||||
return &api.InteractionResponseData{
|
||||
Content: option.NewNullableString("Prompt: " + options.Prompt + "\n" + "Response: " + respString),
|
||||
AllowedMentions: &api.AllowedMentions{},
|
||||
AllowedMentions: &api.AllowedMentions{}, // don't mention anyone
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -200,54 +192,12 @@ func (h *handler) cmdPic(ctx context.Context, data cmdroute.CommandData) *api.In
|
|||
return errorResponse(err)
|
||||
}
|
||||
|
||||
client, clientError := replicate.NewClient(replicate.WithTokenFromEnv())
|
||||
if clientError != nil {
|
||||
return errorResponse(clientError)
|
||||
}
|
||||
if err := data.Options.Unmarshal(&options); err != nil {
|
||||
imageFile, err := lib.ReplicateImageGeneration(options.Prompt)
|
||||
|
||||
if err != nil {
|
||||
return errorResponse(err)
|
||||
}
|
||||
|
||||
input := replicate.PredictionInput{
|
||||
"prompt": options.Prompt,
|
||||
}
|
||||
webhook := replicate.Webhook{
|
||||
URL: "https://example.com/webhook",
|
||||
Events: []replicate.WebhookEventType{"start", "completed"},
|
||||
}
|
||||
|
||||
prediction, predictionError := client.Run(context.Background(), "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b", input, &webhook)
|
||||
|
||||
if predictionError != nil {
|
||||
return errorResponse(predictionError)
|
||||
}
|
||||
|
||||
test, ok := prediction.([]interface{})
|
||||
|
||||
if !ok {
|
||||
return errorResponse(errors.New("there was an error generating the image based on this prompt... this usually happens when the generated image violates safety requirements"))
|
||||
}
|
||||
|
||||
imgUrl, ok := test[0].(string)
|
||||
|
||||
if !ok {
|
||||
return errorResponse(errors.New("there was an error generating the image based on this prompt... this usually happens when the generated image violates safety requirements"))
|
||||
}
|
||||
|
||||
imageRes, imageGetErr := http.Get(imgUrl)
|
||||
if imageGetErr != nil {
|
||||
return errorResponse(imageGetErr)
|
||||
}
|
||||
|
||||
defer imageRes.Body.Close()
|
||||
|
||||
imageBytes, imgReadErr := io.ReadAll(imageRes.Body)
|
||||
if imgReadErr != nil {
|
||||
return errorResponse(imgReadErr)
|
||||
}
|
||||
|
||||
imageFile := bytes.NewBuffer(imageBytes)
|
||||
|
||||
file := sendpart.File{
|
||||
Name: "himbot_response.png",
|
||||
Reader: imageFile,
|
||||
|
@ -276,35 +226,12 @@ func (h *handler) cmdHDPic(ctx context.Context, data cmdroute.CommandData) *api.
|
|||
return errorResponse(err)
|
||||
}
|
||||
|
||||
client := openai.NewClient(os.Getenv("OPENAI_API_KEY"))
|
||||
|
||||
// Send the generation request to DALL·E 3
|
||||
resp, err := client.CreateImage(context.Background(), openai.ImageRequest{
|
||||
Prompt: options.Prompt,
|
||||
Model: "dall-e-3",
|
||||
Size: "1024x1024",
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("Image creation error: %v\n", err)
|
||||
return errorResponse(fmt.Errorf("failed to generate image"))
|
||||
}
|
||||
|
||||
imageRes, err := http.Get(resp.Data[0].URL)
|
||||
imageFile, err := lib.OpenAIImageGeneration(options.Prompt)
|
||||
|
||||
if err != nil {
|
||||
return errorResponse(err)
|
||||
}
|
||||
|
||||
defer imageRes.Body.Close()
|
||||
|
||||
imageBytes, err := io.ReadAll(imageRes.Body)
|
||||
|
||||
if err != nil {
|
||||
return errorResponse(err)
|
||||
}
|
||||
|
||||
imageFile := bytes.NewBuffer(imageBytes)
|
||||
|
||||
file := sendpart.File{
|
||||
Name: "himbot_response.png",
|
||||
Reader: imageFile,
|
||||
|
|
Loading…
Add table
Reference in a new issue