diff --git a/command/ask.go b/command/ask.go new file mode 100644 index 0000000..2166b1a --- /dev/null +++ b/command/ask.go @@ -0,0 +1,64 @@ +package command + +import ( + "bytes" + "context" + "errors" + "fmt" + "himbot/lib" + "time" + + "github.com/diamondburned/arikawa/v3/api" + "github.com/diamondburned/arikawa/v3/api/cmdroute" + "github.com/diamondburned/arikawa/v3/utils/json/option" + "github.com/diamondburned/arikawa/v3/utils/sendpart" +) + +func Ask(ctx context.Context, data cmdroute.CommandData) *api.InteractionResponseData { + // Cooldown Logic + allowed, cooldownString := lib.CooldownHandler(*data.Event, "ask", time.Minute) + + if !allowed { + return lib.ErrorResponse(errors.New(cooldownString)) + } + + // Command Logic + var options struct { + Prompt string `discord:"prompt"` + } + + if err := data.Options.Unmarshal(&options); err != nil { + lib.CancelCooldown(data.Event.User.ID.String(), "ask") + return lib.ErrorResponse(err) + } + + respString, err := lib.ReplicateTextGeneration(options.Prompt) + + if err != nil { + fmt.Printf("ChatCompletion error: %v\n", err) + lib.CancelCooldown(data.Event.User.ID.String(), "ask") + return &api.InteractionResponseData{ + Content: option.NewNullableString("ChatCompletion Error!"), + AllowedMentions: &api.AllowedMentions{}, + } + } + + if len(respString) > 1800 { + textFile := bytes.NewBuffer([]byte(respString)) + + file := sendpart.File{ + Name: "himbot_response.txt", + Reader: textFile, + } + + return &api.InteractionResponseData{ + Content: option.NewNullableString("Prompt: " + options.Prompt + "\n" + "Response:\n"), + AllowedMentions: &api.AllowedMentions{}, + Files: []sendpart.File{file}, + } + } + return &api.InteractionResponseData{ + Content: option.NewNullableString("Prompt: " + options.Prompt + "\n" + "Response: " + respString), + AllowedMentions: &api.AllowedMentions{}, + } +} diff --git a/command/hs.go b/command/hs.go new file mode 100644 index 0000000..b1d0705 --- /dev/null +++ b/command/hs.go @@ -0,0 +1,26 @@ +package command + +import ( + "context" + "himbot/lib" + + "github.com/diamondburned/arikawa/v3/api" + "github.com/diamondburned/arikawa/v3/api/cmdroute" + "github.com/diamondburned/arikawa/v3/utils/json/option" +) + +func HS(ctx context.Context, data cmdroute.CommandData) *api.InteractionResponseData { + var options struct { + Arg string `discord:"nickname"` + } + + if err := data.Options.Unmarshal(&options); err != nil { + return lib.ErrorResponse(err) + } + + user := lib.GetUserObject(*data.Event) + + return &api.InteractionResponseData{ + Content: option.NewNullableString(options.Arg + " was " + user.DisplayName() + "'s nickname in highschool!"), + } +} diff --git a/command/pic.go b/command/pic.go new file mode 100644 index 0000000..ee1d58f --- /dev/null +++ b/command/pic.go @@ -0,0 +1,56 @@ +package command + +import ( + "context" + "errors" + "himbot/lib" + "strconv" + "time" + + "github.com/diamondburned/arikawa/v3/api" + "github.com/diamondburned/arikawa/v3/api/cmdroute" + "github.com/diamondburned/arikawa/v3/utils/json/option" + "github.com/diamondburned/arikawa/v3/utils/sendpart" +) + +func Pic(ctx context.Context, data cmdroute.CommandData) *api.InteractionResponseData { + // Cooldown Logic + allowed, cooldownString := lib.CooldownHandler(*data.Event, "pic", time.Minute*2) + + if !allowed { + return lib.ErrorResponse(errors.New(cooldownString)) + } + + // Command Logic + var options struct { + Prompt string `discord:"prompt"` + } + + if err := data.Options.Unmarshal(&options); err != nil { + lib.CancelCooldown(data.Event.User.ID.String(), "pic") + return lib.ErrorResponse(err) + } + + // Get current epoch timestamp + timestamp := strconv.FormatInt(time.Now().Unix(), 10) + + // Concatenate clean username and timestamp to form filename + filename := data.Event.Sender().Username + "_" + timestamp + ".jpg" + + imageFile, err := lib.ReplicateImageGeneration(options.Prompt, filename) + + if err != nil { + lib.CancelCooldown(data.Event.User.ID.String(), "pic") + return lib.ErrorResponse(err) + } + + file := sendpart.File{ + Name: filename, + Reader: imageFile, + } + + return &api.InteractionResponseData{ + Content: option.NewNullableString("Prompt: " + options.Prompt), + Files: []sendpart.File{file}, + } +} diff --git a/command/ping.go b/command/ping.go new file mode 100644 index 0000000..3a894da --- /dev/null +++ b/command/ping.go @@ -0,0 +1,16 @@ +package command + +import ( + "context" + + "github.com/diamondburned/arikawa/v3/api" + "github.com/diamondburned/arikawa/v3/api/cmdroute" + "github.com/diamondburned/arikawa/v3/utils/json/option" +) + +func Ping(ctx context.Context, data cmdroute.CommandData) *api.InteractionResponseData { + // Command Logic + return &api.InteractionResponseData{ + Content: option.NewNullableString("Pong!"), + } +} diff --git a/go.sum b/go.sum index 1ec78c4..8af460c 100644 --- a/go.sum +++ b/go.sum @@ -13,6 +13,7 @@ github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/ github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= +github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= @@ -51,6 +52,8 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= @@ -60,5 +63,7 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/lib/errors.go b/lib/errors.go new file mode 100644 index 0000000..a491ebb --- /dev/null +++ b/lib/errors.go @@ -0,0 +1,28 @@ +package lib + +import ( + "net" + "os" + + "github.com/diamondburned/arikawa/v3/api" + "github.com/diamondburned/arikawa/v3/discord" + "github.com/diamondburned/arikawa/v3/utils/json/option" +) + +func ErrorResponse(err error) *api.InteractionResponseData { + var content string + switch e := err.(type) { + case *net.OpError: + content = "**Network Error:** " + e.Error() + case *os.PathError: + content = "**File Error:** " + e.Error() + default: + content = "**Error:** " + err.Error() + } + + return &api.InteractionResponseData{ + Content: option.NewNullableString(content), + Flags: discord.EphemeralMessage, + AllowedMentions: &api.AllowedMentions{}, + } +} diff --git a/lib/replicate.go b/lib/replicate.go index eb3d44d..cce3089 100644 --- a/lib/replicate.go +++ b/lib/replicate.go @@ -20,15 +20,14 @@ func ReplicateTextGeneration(prompt string) (string, error) { } input := replicate.PredictionInput{ - "prompt": prompt, - "max_new_tokens": 4096, + "prompt": prompt, } webhook := replicate.Webhook{ URL: "https://example.com/webhook", Events: []replicate.WebhookEventType{"start", "completed"}, } - prediction, predictionError := client.Run(context.Background(), "mistralai/mistral-7b-instruct-v0.2:79052a3adbba8116ebc6697dcba67ad0d58feff23e7aeb2f103fc9aa545f9269", input, &webhook) + prediction, predictionError := client.Run(context.Background(), "meta/llama-2-70b-chat:2d19859030ff705a87c746f7e96eea03aefb71f166725aee39692f1476566d48", input, &webhook) if predictionError != nil { return "", predictionError diff --git a/main.go b/main.go index 7a3b950..c44f77a 100644 --- a/main.go +++ b/main.go @@ -1,25 +1,17 @@ package main import ( - "bytes" "context" - "errors" - "fmt" - "himbot/lib" + "himbot/command" "log" - "net" "os" "os/signal" - "strconv" - "time" "github.com/diamondburned/arikawa/v3/api" "github.com/diamondburned/arikawa/v3/api/cmdroute" "github.com/diamondburned/arikawa/v3/discord" "github.com/diamondburned/arikawa/v3/gateway" "github.com/diamondburned/arikawa/v3/state" - "github.com/diamondburned/arikawa/v3/utils/json/option" - "github.com/diamondburned/arikawa/v3/utils/sendpart" "github.com/joho/godotenv" ) @@ -102,142 +94,10 @@ func newHandler(s *state.State) *handler { h.Router = cmdroute.NewRouter() // Automatically defer handles if they're slow. h.Use(cmdroute.Deferrable(s, cmdroute.DeferOpts{})) - h.AddFunc("ping", h.cmdPing) - h.AddFunc("ask", h.cmdAsk) - h.AddFunc("pic", h.cmdPic) - h.AddFunc("hs", h.cmdHS) + h.AddFunc("ping", command.Ping) + h.AddFunc("ask", command.Ask) + h.AddFunc("pic", command.Pic) + h.AddFunc("hs", command.HS) return h } - -func (h *handler) cmdPing(ctx context.Context, data cmdroute.CommandData) *api.InteractionResponseData { - // Command Logic - return &api.InteractionResponseData{ - Content: option.NewNullableString("Pong!"), - } -} - -func (h *handler) cmdAsk(ctx context.Context, data cmdroute.CommandData) *api.InteractionResponseData { - // Cooldown Logic - allowed, cooldownString := lib.CooldownHandler(*data.Event, "ask", time.Minute) - - if !allowed { - return errorResponse(errors.New(cooldownString)) - } - - // Command Logic - var options struct { - Prompt string `discord:"prompt"` - } - - if err := data.Options.Unmarshal(&options); err != nil { - lib.CancelCooldown(data.Event.User.ID.String(), "ask") - return errorResponse(err) - } - - respString, err := lib.ReplicateTextGeneration(options.Prompt) - - if err != nil { - fmt.Printf("ChatCompletion error: %v\n", err) - lib.CancelCooldown(data.Event.User.ID.String(), "ask") - return &api.InteractionResponseData{ - Content: option.NewNullableString("ChatCompletion Error!"), - AllowedMentions: &api.AllowedMentions{}, - } - } - - if len(respString) > 1800 { - textFile := bytes.NewBuffer([]byte(respString)) - - file := sendpart.File{ - Name: "himbot_response.txt", - Reader: textFile, - } - - return &api.InteractionResponseData{ - Content: option.NewNullableString("Prompt: " + options.Prompt + "\n" + "Response:\n"), - AllowedMentions: &api.AllowedMentions{}, - Files: []sendpart.File{file}, - } - } - return &api.InteractionResponseData{ - Content: option.NewNullableString("Prompt: " + options.Prompt + "\n" + "Response: " + respString), - AllowedMentions: &api.AllowedMentions{}, - } -} - -func (h *handler) cmdPic(ctx context.Context, data cmdroute.CommandData) *api.InteractionResponseData { - // Cooldown Logic - allowed, cooldownString := lib.CooldownHandler(*data.Event, "pic", time.Minute*2) - - if !allowed { - return errorResponse(errors.New(cooldownString)) - } - - // Command Logic - var options struct { - Prompt string `discord:"prompt"` - } - - if err := data.Options.Unmarshal(&options); err != nil { - lib.CancelCooldown(data.Event.User.ID.String(), "pic") - return errorResponse(err) - } - - // Get current epoch timestamp - timestamp := strconv.FormatInt(time.Now().Unix(), 10) - - // Concatenate clean username and timestamp to form filename - filename := data.Event.Sender().Username + "_" + timestamp + ".jpg" - - imageFile, err := lib.ReplicateImageGeneration(options.Prompt, filename) - - if err != nil { - lib.CancelCooldown(data.Event.User.ID.String(), "pic") - return errorResponse(err) - } - - file := sendpart.File{ - Name: filename, - Reader: imageFile, - } - - return &api.InteractionResponseData{ - Content: option.NewNullableString("Prompt: " + options.Prompt), - Files: []sendpart.File{file}, - } -} - -func (h *handler) cmdHS(ctx context.Context, data cmdroute.CommandData) *api.InteractionResponseData { - var options struct { - Arg string `discord:"nickname"` - } - - if err := data.Options.Unmarshal(&options); err != nil { - return errorResponse(err) - } - - user := lib.GetUserObject(*data.Event) - - return &api.InteractionResponseData{ - Content: option.NewNullableString(options.Arg + " was " + user.DisplayName() + "'s nickname in highschool!"), - } -} - -func errorResponse(err error) *api.InteractionResponseData { - var content string - switch e := err.(type) { - case *net.OpError: - content = "**Network Error:** " + e.Error() - case *os.PathError: - content = "**File Error:** " + e.Error() - default: - content = "**Error:** " + err.Error() - } - - return &api.InteractionResponseData{ - Content: option.NewNullableString(content), - Flags: discord.EphemeralMessage, - AllowedMentions: &api.AllowedMentions{}, - } -}