Replicate
This commit is contained in:
107
lib/openai.go
107
lib/openai.go
@ -1,107 +0,0 @@
|
||||
package lib
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
var PromptPrefix = "Your name is Himbot. You are a helpful but sarcastic and witty discord bot. Please respond with a natural response to the following prompt with that personality in mind:"
|
||||
|
||||
func OpenAITextGeneration(prompt string) (string, error) {
|
||||
godotenv.Load(".env")
|
||||
apiKey := os.Getenv("OPENAI_API_KEY")
|
||||
|
||||
client := openai.NewClient(apiKey)
|
||||
|
||||
resp, err := client.CreateChatCompletion(
|
||||
context.Background(),
|
||||
openai.ChatCompletionRequest{
|
||||
Model: openai.GPT4Turbo1106,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: PromptPrefix + prompt,
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
fmt.Printf("Ask command error: %v\n", err)
|
||||
return "", errors.New("https://fly.storage.tigris.dev/atridad/himbot/no.gif")
|
||||
}
|
||||
|
||||
return resp.Choices[0].Message.Content, nil
|
||||
}
|
||||
|
||||
func OpenAIImageGeneration(prompt string, filename string) (*bytes.Buffer, error) {
|
||||
godotenv.Load(".env")
|
||||
apiKey := os.Getenv("OPENAI_API_KEY")
|
||||
|
||||
client := openai.NewClient(apiKey)
|
||||
|
||||
imageResponse, err := client.CreateImage(
|
||||
context.Background(),
|
||||
openai.ImageRequest{
|
||||
Model: openai.CreateImageModelDallE3,
|
||||
Prompt: prompt,
|
||||
Size: openai.CreateImageSize1024x1024,
|
||||
Quality: openai.CreateImageQualityStandard,
|
||||
ResponseFormat: openai.CreateImageResponseFormatURL,
|
||||
N: 1,
|
||||
},
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
fmt.Printf("Pic command error: %v\n", err)
|
||||
return nil, errors.New("https://fly.storage.tigris.dev/atridad/himbot/hornypolice.gif")
|
||||
}
|
||||
|
||||
imgUrl := imageResponse.Data[0].URL
|
||||
|
||||
imageRes, imageGetErr := http.Get(imgUrl)
|
||||
if imageGetErr != nil {
|
||||
return nil, imageGetErr
|
||||
}
|
||||
|
||||
defer imageRes.Body.Close()
|
||||
|
||||
imageBytes, imgReadErr := io.ReadAll(imageRes.Body)
|
||||
if imgReadErr != nil {
|
||||
return nil, imgReadErr
|
||||
}
|
||||
|
||||
// Save image to a temporary file
|
||||
tmpfile, err := os.Create(filename)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
defer os.Remove(tmpfile.Name())
|
||||
|
||||
if _, err := tmpfile.Write(imageBytes); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if err := tmpfile.Close(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Upload the image to S3
|
||||
_, uploadErr := UploadToS3(tmpfile.Name())
|
||||
if uploadErr != nil {
|
||||
log.Printf("Failed to upload image to S3: %v", uploadErr)
|
||||
}
|
||||
|
||||
imageFile := bytes.NewBuffer(imageBytes)
|
||||
return imageFile, nil
|
||||
}
|
147
lib/replicate.go
Normal file
147
lib/replicate.go
Normal file
@ -0,0 +1,147 @@
|
||||
package lib
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/replicate/replicate-go"
|
||||
)
|
||||
|
||||
var ReplicatePromptPrefix = "Your name is Himbot. You are a helpful but sarcastic and witty discord bot. Please respond with a natural response to the following prompt with that personality in mind:"
|
||||
|
||||
func ReplicateTextGeneration(prompt string) (string, error) {
|
||||
client, clientError := replicate.NewClient(replicate.WithTokenFromEnv())
|
||||
if clientError != nil {
|
||||
return "", clientError
|
||||
}
|
||||
|
||||
input := replicate.PredictionInput{
|
||||
"prompt": ReplicatePromptPrefix + prompt,
|
||||
"max_new_tokens": 4096,
|
||||
"prompt_template": "<s>[INST] {prompt} [/INST]",
|
||||
}
|
||||
|
||||
webhook := replicate.Webhook{
|
||||
URL: "https://example.com/webhook",
|
||||
Events: []replicate.WebhookEventType{"start", "completed"},
|
||||
}
|
||||
|
||||
prediction, predictionError := client.Run(context.Background(), "mistralai/mixtral-8x7b-instruct-v0.1:cf18decbf51c27fed6bbdc3492312c1c903222a56e3fe9ca02d6cbe5198afc10", input, &webhook)
|
||||
|
||||
if predictionError != nil {
|
||||
return "", predictionError
|
||||
}
|
||||
|
||||
if prediction == nil {
|
||||
return "", errors.New("there was an error generating a response based on this prompt... please reach out to @himbothyswaggins to fix this issue")
|
||||
}
|
||||
|
||||
test, ok := prediction.([]interface{})
|
||||
|
||||
if !ok {
|
||||
return "", errors.New("there was an error generating a response based on this prompt... please reach out to @himbothyswaggins to fix this issue")
|
||||
}
|
||||
|
||||
strs := make([]string, len(test))
|
||||
for i, v := range test {
|
||||
str, ok := v.(string)
|
||||
if !ok {
|
||||
return "", errors.New("element is not a string")
|
||||
}
|
||||
strs[i] = str
|
||||
}
|
||||
|
||||
result := strings.Join(strs, "")
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func ReplicateImageGeneration(prompt string, filename string) (*bytes.Buffer, error) {
|
||||
client, clientError := replicate.NewClient(replicate.WithTokenFromEnv())
|
||||
if clientError != nil {
|
||||
return nil, clientError
|
||||
}
|
||||
|
||||
input := replicate.PredictionInput{
|
||||
"width": 1024,
|
||||
"height": 1024,
|
||||
"prompt": prompt,
|
||||
"scheduler": "K_EULER",
|
||||
"num_outputs": 1,
|
||||
"guidance_scale": 7.5,
|
||||
"lora_scale": 0.65,
|
||||
"lora_weights": "https://replicate.delivery/pbxt/hM1H6f93HCVYQq471gZz6EYtRHPMJYAsyxeQXdGnozeDJKOkA/trained_model.tar",
|
||||
"negative_prompt": "ugly, deformed, noisy, blurry, low contrast, text, BadDream, lowres, low resolution, mutated body parts, extra limbs, mutated body parts, inaccurate hands, too many hands, deformed fingers, too many fingers, deformed eyes, deformed faces, unrealistic faces",
|
||||
"num_inference_steps": 35,
|
||||
"disable_safety_checker": true,
|
||||
}
|
||||
webhook := replicate.Webhook{
|
||||
URL: "https://example.com/webhook",
|
||||
Events: []replicate.WebhookEventType{"start", "completed"},
|
||||
}
|
||||
|
||||
prediction, predictionError := client.Run(context.Background(), "batouresearch/open-dalle-1.1-lora:2ade2cbfc88298b98366a6e361559e11666c17ed415d341c9ae776b30a61b196", input, &webhook)
|
||||
|
||||
if predictionError != nil {
|
||||
return nil, predictionError
|
||||
}
|
||||
|
||||
if prediction == nil {
|
||||
return nil, errors.New("there was an error generating the image based on this prompt... please reach out to @himbothyswaggins to fix this issue")
|
||||
}
|
||||
|
||||
test, ok := prediction.([]interface{})
|
||||
|
||||
if !ok {
|
||||
return nil, errors.New("there was an error generating the image based on this prompt... please reach out to @himbothyswaggins to fix this issue")
|
||||
}
|
||||
|
||||
imgUrl, ok := test[0].(string)
|
||||
|
||||
if !ok {
|
||||
return nil, errors.New("there was an error generating the image based on this prompt... please reach out to @himbothyswaggins to fix this issue")
|
||||
}
|
||||
|
||||
imageRes, imageGetErr := http.Get(imgUrl)
|
||||
if imageGetErr != nil {
|
||||
return nil, imageGetErr
|
||||
}
|
||||
|
||||
defer imageRes.Body.Close()
|
||||
|
||||
imageBytes, imgReadErr := io.ReadAll(imageRes.Body)
|
||||
if imgReadErr != nil {
|
||||
return nil, imgReadErr
|
||||
}
|
||||
|
||||
// Save image to a temporary file
|
||||
tmpfile, err := os.Create(filename)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
defer os.Remove(tmpfile.Name())
|
||||
|
||||
if _, err := tmpfile.Write(imageBytes); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if err := tmpfile.Close(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Upload the image to S3
|
||||
_, uploadErr := UploadToS3(tmpfile.Name())
|
||||
if uploadErr != nil {
|
||||
log.Printf("Failed to upload image to S3: %v", uploadErr)
|
||||
}
|
||||
|
||||
imageFile := bytes.NewBuffer(imageBytes)
|
||||
return imageFile, nil
|
||||
}
|
Reference in New Issue
Block a user