diff --git a/lib/replicate.go b/lib/replicate.go index bf7f33c..77deb5f 100644 --- a/lib/replicate.go +++ b/lib/replicate.go @@ -6,10 +6,51 @@ import ( "errors" "io" "net/http" + "strings" "github.com/replicate/replicate-go" ) +func ReplicateTextGeneration(prompt string) (string, error) { + client, clientError := replicate.NewClient(replicate.WithTokenFromEnv()) + if clientError != nil { + return "", clientError + } + + input := replicate.PredictionInput{ + "prompt": prompt, + } + webhook := replicate.Webhook{ + URL: "https://example.com/webhook", + Events: []replicate.WebhookEventType{"start", "completed"}, + } + + prediction, predictionError := client.Run(context.Background(), "mistralai/mixtral-8x7b-instruct-v0.1:7b3212fbaf88310cfef07a061ce94224e82efc8403c26fc67e8f6c065de51f21", input, &webhook) + + if predictionError != nil { + return "", predictionError + } + + test, ok := prediction.([]interface{}) + + if !ok { + return "", errors.New("there was an error generating the image based on this prompt... this usually happens when the generated image violates safety requirements") + } + + strs := make([]string, len(test)) + for i, v := range test { + str, ok := v.(string) + if !ok { + return "", errors.New("element is not a string") + } + strs[i] = str + } + + result := strings.Join(strs, "") + + return result, nil +} + func ReplicateImageGeneration(prompt string) (*bytes.Buffer, error) { client, clientError := replicate.NewClient(replicate.WithTokenFromEnv()) if clientError != nil { diff --git a/main.go b/main.go index 1e1c16c..0f3ad47 100644 --- a/main.go +++ b/main.go @@ -145,13 +145,13 @@ func (h *handler) cmdAsk(ctx context.Context, data cmdroute.CommandData) *api.In return errorResponse(err) } - respString, err := lib.OpenAITextGeneration(options.Prompt) + respString, err := lib.ReplicateTextGeneration(options.Prompt) if err != nil { fmt.Printf("ChatCompletion error: %v\n", err) return &api.InteractionResponseData{ Content: option.NewNullableString("ChatCompletion Error!"), - AllowedMentions: &api.AllowedMentions{}, // don't mention anyone + AllowedMentions: &api.AllowedMentions{}, } } @@ -165,13 +165,13 @@ func (h *handler) cmdAsk(ctx context.Context, data cmdroute.CommandData) *api.In return &api.InteractionResponseData{ Content: option.NewNullableString("Prompt: " + options.Prompt + "\n" + "Response:\n"), - AllowedMentions: &api.AllowedMentions{}, // don't mention anyone + AllowedMentions: &api.AllowedMentions{}, Files: []sendpart.File{file}, } } return &api.InteractionResponseData{ Content: option.NewNullableString("Prompt: " + options.Prompt + "\n" + "Response: " + respString), - AllowedMentions: &api.AllowedMentions{}, // don't mention anyone + AllowedMentions: &api.AllowedMentions{}, } }