Files
himbot/command/markov.go
Atridad Lahiji 0081978489
All checks were successful
Docker Deploy / build-and-push (push) Successful in 3m22s
Made it more MaRkOvIaN!!!!!!
2026-01-19 22:33:55 -07:00

292 lines
6.4 KiB
Go

package command
import (
"crypto/md5"
"fmt"
"himbot/lib"
"math/rand"
"regexp"
"strings"
"sync"
"time"
"github.com/bwmarrin/discordgo"
)
// MarkovData is a Markov chain whose state is the previous two words of text.
type MarkovData struct {
	// Chain maps a two-word prefix ("word1 word2") to every word that was seen
	// directly after it. Duplicates are kept, so successors that occur more
	// often are proportionally more likely to be picked at random.
	Chain map[string][]string // "word1 word2" -> ["word3", ...]
	// Starts holds the opening two-word prefix of each source message that had
	// at least three words (duplicates kept). Generation begins from one of these
	// when no seeded prefix is chosen.
	Starts []string
}
// MarkovCache stores built chains keyed by "channelID:numMessages".
type MarkovCache struct {
	// data maps a cache key to its built chain.
	data map[string]*MarkovData
	// hashes maps a cache key to the hex MD5 digest of the concatenated IDs of
	// the messages the chain was built from.
	// NOTE(review): nothing in this file reads it back — confirm whether it is
	// still needed.
	hashes map[string]string
	// mu guards both data and hashes.
	mu sync.RWMutex
}
// markovCache is the single shared store of built chains used by both Markov
// commands.
var markovCache = &MarkovCache{
	data:   make(map[string]*MarkovData),
	hashes: make(map[string]string),
}

// Patterns stripped from message text before it is fed into a chain.
var (
	// urlRegex matches an http:// or https:// URL up to the next whitespace.
	urlRegex = regexp.MustCompile(`https?://[^\s]+`)
	// mentionRegex matches angle-bracket tokens that begin with @, #, & or !
	// (Discord-style <@id>, <#id>, <@&id>, <@!id> mentions).
	mentionRegex = regexp.MustCompile(`<[@#&!][^>]+>`)
)
// MarkovCommand handles the plain Markov slash command. It builds (or reuses)
// a chain from the channel's recent non-bot messages and returns one generated
// message.
//
// The optional integer "messages" option sets how many messages to learn
// from. Non-positive values fall back to lib.AppConfig.MarkovDefaultMessages,
// and values above lib.AppConfig.MarkovMaxMessages are clamped to that limit.
// Errors from fetching channel history are returned unchanged.
func MarkovCommand(s *discordgo.Session, i *discordgo.InteractionCreate) (string, error) {
	channelID := i.ChannelID
	numMessages := lib.AppConfig.MarkovDefaultMessages

	// Look the option up by name instead of assuming it is Options[0], so the
	// command keeps honouring "messages" if other options are ever added.
	for _, option := range i.ApplicationCommandData().Options {
		if option.Name != "messages" {
			continue
		}
		numMessages = int(option.IntValue())
		if numMessages <= 0 {
			numMessages = lib.AppConfig.MarkovDefaultMessages
		} else if numMessages > lib.AppConfig.MarkovMaxMessages {
			numMessages = lib.AppConfig.MarkovMaxMessages
		}
	}

	// Chains are cached per channel and message count. A hit skips fetching
	// entirely, so the chain does not see newer messages until the entry is
	// evicted.
	cacheKey := fmt.Sprintf("%s:%d", channelID, numMessages)
	if data := getCachedChain(cacheKey); data != nil {
		if msg := generateMessage(data, ""); msg != "" {
			return msg, nil
		}
	}

	allMessages, err := fetchMessages(s, channelID, numMessages)
	if err != nil {
		return "", err
	}
	data := buildMarkovChain(allMessages)
	setCachedChain(cacheKey, data, allMessages)

	newMessage := generateMessage(data, "")
	if newMessage == "" {
		newMessage = "Not enough text data to generate a message."
	}
	return newMessage, nil
}
// MarkovQuestionCommand answers a user's question with Markov-generated text.
// Where possible it starts from chain prefixes containing words from the
// question.
//
// Options:
//   - "question": the text to answer. If it is empty, a prompt asking for a
//     question is returned instead.
//   - "messages": optional history size. Non-positive values use
//     lib.AppConfig.MarkovDefaultMessages; values above
//     lib.AppConfig.MarkovMaxMessages are clamped to it.
//
// The reply has the form "**Q:** <question>\n**A:** <answer>".
// NOTE(review): the question is echoed verbatim. Confirm the caller restricts
// AllowedMentions so a question containing mass mentions cannot ping anyone.
func MarkovQuestionCommand(s *discordgo.Session, i *discordgo.InteractionCreate) (string, error) {
	question := ""
	numMessages := lib.AppConfig.MarkovDefaultMessages

	for _, opt := range i.ApplicationCommandData().Options {
		if opt.Name == "question" {
			question = opt.StringValue()
			continue
		}
		if opt.Name != "messages" {
			continue
		}
		requested := int(opt.IntValue())
		switch {
		case requested <= 0:
			numMessages = lib.AppConfig.MarkovDefaultMessages
		case requested > lib.AppConfig.MarkovMaxMessages:
			numMessages = lib.AppConfig.MarkovMaxMessages
		default:
			numMessages = requested
		}
	}

	if question == "" {
		return "Please provide a question!", nil
	}

	// Reuse a cached chain for this channel/size if one exists; otherwise
	// fetch history, build a chain, and offer it to the cache.
	cacheKey := fmt.Sprintf("%s:%d", i.ChannelID, numMessages)
	data := getCachedChain(cacheKey)
	if data == nil {
		history, err := fetchMessages(s, i.ChannelID, numMessages)
		if err != nil {
			return "", err
		}
		data = buildMarkovChain(history)
		setCachedChain(cacheKey, data, history)
	}

	answer := generateMessage(data, question)
	if answer == "" {
		answer = "I don't have enough context to answer that."
	}
	return fmt.Sprintf("**Q:** %s\n**A:** %s", question, answer), nil
}
// getCachedChain returns the chain stored under cacheKey, or nil if none is
// cached. Only the read lock is taken, so concurrent lookups do not block
// each other.
func getCachedChain(cacheKey string) *MarkovData {
	markovCache.mu.RLock()
	cached := markovCache.data[cacheKey]
	markovCache.mu.RUnlock()
	return cached
}
// setCachedChain stores data under cacheKey, together with a digest of the
// messages it was built from.
//
// Chains with no start prefixes (or a nil chain) are not cached, so a channel
// without usable text is fetched again next time. A non-positive
// lib.AppConfig.MarkovCacheSize disables caching. When the cache grows past
// that size, one entry other than the one just stored is evicted. The victim
// is whichever key map iteration yields first, which is effectively arbitrary.
func setCachedChain(cacheKey string, data *MarkovData, messages []*discordgo.Message) {
	if data == nil || len(data.Starts) == 0 {
		return
	}
	limit := lib.AppConfig.MarkovCacheSize
	if limit <= 0 {
		// Previously every insert was immediately evicted at this size; keep
		// that "cache nothing" meaning explicitly.
		return
	}

	hash := hashMessages(messages)

	markovCache.mu.Lock()
	defer markovCache.mu.Unlock()

	markovCache.data[cacheKey] = data
	markovCache.hashes[cacheKey] = hash

	if len(markovCache.data) <= limit {
		return
	}
	for k := range markovCache.data {
		// Never evict the entry just stored. With random map order the old
		// loop could discard the chain the caller had just fetched and built.
		if k == cacheKey {
			continue
		}
		delete(markovCache.data, k)
		delete(markovCache.hashes, k)
		break
	}
}
// hashMessages returns the lowercase hex MD5 digest of the IDs of messages,
// fed in order with no separator. It is used as a fingerprint of the message
// set a chain was built from.
func hashMessages(messages []*discordgo.Message) string {
	digest := md5.New()
	for _, m := range messages {
		// hash.Hash.Write is documented never to return an error.
		digest.Write([]byte(m.ID))
	}
	return fmt.Sprintf("%x", digest.Sum(nil))
}
// fetchMessages reads channelID's history in requests of at most 100
// messages. It keeps only messages with a non-nil, non-bot author and
// non-blank content, stopping once numMessages have been kept or the channel
// runs out of history.
//
// Each follow-up request passes the last ID of the previous batch as the
// "before" cursor. This assumes discordgo's ChannelMessages(channelID, limit,
// beforeID, afterID, aroundID) parameter order — TODO confirm against the
// pinned discordgo version. A non-positive numMessages returns nil, nil.
// API errors are returned unchanged.
func fetchMessages(s *discordgo.Session, channelID string, numMessages int) ([]*discordgo.Message, error) {
	const maxBatch = 100 // per-request cap used for channel history

	if numMessages <= 0 {
		return nil, nil
	}

	kept := make([]*discordgo.Message, 0, numMessages)
	before := ""
	for len(kept) < numMessages {
		batchSize := numMessages - len(kept)
		if batchSize > maxBatch {
			batchSize = maxBatch
		}

		batch, err := s.ChannelMessages(channelID, batchSize, before, "", "")
		if err != nil {
			return nil, err
		}
		if len(batch) == 0 {
			break
		}

		for _, msg := range batch {
			if msg.Author != nil && !msg.Author.Bot && strings.TrimSpace(msg.Content) != "" {
				kept = append(kept, msg)
			}
		}
		before = batch[len(batch)-1].ID

		// A short page means history is exhausted. Compare against what was
		// requested, not the 100-message cap. Once bot or empty messages are
		// filtered out, follow-up requests ask for fewer than 100, and treating
		// those as "end of channel" cut the fetch short.
		if len(batch) < batchSize {
			break
		}
	}
	return kept, nil
}
// cleanText prepares raw message text for chain building or seed matching. It
// removes http(s) URLs and <@…>/<#…>/<&…>/<!…>-style mention tokens, strips
// the plain-text @everyone and @here pings, and collapses every run of
// whitespace into a single space (trimming both ends).
func cleanText(text string) string {
	text = urlRegex.ReplaceAllString(text, "")
	text = mentionRegex.ReplaceAllString(text, "")

	// @everyone/@here are not <…> tokens, so mentionRegex misses them, and
	// generated text could echo a mass ping from channel history. Repeat until
	// stable so a removal cannot splice a new ping together (e.g.
	// "@every@hereone"). Each pass that changes the text shortens it, so the
	// loop terminates.
	for {
		stripped := strings.ReplaceAll(strings.ReplaceAll(text, "@everyone", ""), "@here", "")
		if stripped == text {
			break
		}
		text = stripped
	}

	return strings.Join(strings.Fields(text), " ")
}
// buildMarkovChain builds a two-word-prefix chain from messages.
//
// Each message's content is passed through cleanText and split on whitespace.
// Messages with fewer than three words (including those that clean down to
// nothing) are skipped, because they contain no prefix→successor triple. For
// the rest, the first two words are recorded as a start prefix, and every
// consecutive triple (a, b, c) appends c to Chain["a b"].
func buildMarkovChain(messages []*discordgo.Message) *MarkovData {
	chain := make(map[string][]string)
	starts := make([]string, 0)

	for _, m := range messages {
		tokens := strings.Fields(cleanText(m.Content))
		if len(tokens) < 3 {
			continue
		}
		starts = append(starts, key(tokens[0], tokens[1]))
		for j := 2; j < len(tokens); j++ {
			prefix := key(tokens[j-2], tokens[j-1])
			chain[prefix] = append(chain[prefix], tokens[j])
		}
	}

	return &MarkovData{Chain: chain, Starts: starts}
}
// generateMessage walks data's chain and returns the generated words joined by
// single spaces. It returns "" if data is nil or has no start prefixes.
//
// Choosing the starting prefix:
//   - If seed is non-empty, it is cleaned and split into words. Each word has
//     surrounding punctuation trimmed (so "weather?" matches "weather"), and
//     only words longer than three bytes are kept. The walk then starts from a
//     random chain prefix that contains one of those words, compared
//     case-insensitively. A prefix matching several words is proportionally
//     more likely.
//   - Otherwise, or if nothing matches, it starts from a random entry of
//     data.Starts.
//
// The output is the two prefix words plus at most 40 successors. Once more
// than six successors have been added, a word containing '.', '!' or '?' ends
// the walk with probability 0.7. The walk also ends at a prefix with no known
// successor.
func generateMessage(data *MarkovData, seed string) string {
	const (
		maxSuccessors   = 40                  // hard cap on words added after the opening prefix
		softStopAfter   = 5                   // punctuation may end the walk only once i exceeds this
		continueChance  = 0.3                 // chance to keep going past sentence-ending punctuation
		seedPunctuation = `.,!?;:"'()[]{}*_~` // trimmed from both ends of each seed word
	)

	if data == nil || len(data.Starts) == 0 {
		return ""
	}

	var currentKey string
	if seed != "" {
		// Normalise the seed words once, instead of lowercasing them again for
		// every chain key.
		var needles []string
		for _, w := range strings.Fields(cleanText(seed)) {
			w = strings.Trim(w, seedPunctuation)
			if len(w) > 3 {
				needles = append(needles, strings.ToLower(w))
			}
		}

		if len(needles) > 0 {
			var candidates []string
			for k := range data.Chain {
				lowerKey := strings.ToLower(k)
				for _, n := range needles {
					if strings.Contains(lowerKey, n) {
						candidates = append(candidates, k)
					}
				}
			}
			if len(candidates) > 0 {
				currentKey = candidates[rand.Intn(len(candidates))]
			}
		}
	}
	if currentKey == "" {
		currentKey = data.Starts[rand.Intn(len(data.Starts))]
	}

	// Keys are built by key() from whitespace-free words, so they contain
	// exactly one space.
	parts := strings.Split(currentKey, " ")
	w1, w2 := parts[0], parts[1]
	output := []string{w1, w2}

	for i := 0; i < maxSuccessors; i++ {
		nextOptions := data.Chain[currentKey]
		if len(nextOptions) == 0 {
			break
		}
		nextWord := nextOptions[rand.Intn(len(nextOptions))]
		output = append(output, nextWord)
		w1, w2 = w2, nextWord
		currentKey = key(w1, w2)

		// Soft stop: once the text has some length, sentence-ending
		// punctuation usually ends it.
		if i > softStopAfter && strings.ContainsAny(nextWord, ".!?") && rand.Float32() > continueChance {
			break
		}
	}
	return strings.Join(output, " ")
}
// key joins two words with a single space, producing the "word1 word2" prefix
// format used for MarkovData.Chain keys and Starts entries.
func key(w1, w2 string) string {
	return strings.Join([]string{w1, w2}, " ")
}
// init seeds math/rand's global source with the current time so separate runs
// generate different text.
// NOTE(review): rand.Seed has been deprecated since Go 1.20, when the global
// source became seeded automatically, and it is a no-op by default from
// Go 1.24. It is kept because it is the only use of the "time" import in this
// file.
func init() {
	seed := time.Now().UnixNano()
	rand.Seed(seed)
}