277 lines
7.0 KiB
Go
277 lines
7.0 KiB
Go
package command
|
|
|
|
import (
	"crypto/md5"
	"fmt"
	"himbot/lib"
	"math/rand"
	"regexp"
	"strings"
	"sync"
	"time"
	"unicode"
	"unicode/utf8"

	"github.com/bwmarrin/discordgo"
)
|
|
|
|
// Cache for Markov chains to avoid rebuilding for the same channel/message count
|
|
type MarkovCache struct {
|
|
chains map[string]map[string][]string
|
|
hashes map[string]string
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
var (
|
|
markovCache = &MarkovCache{
|
|
chains: make(map[string]map[string][]string),
|
|
hashes: make(map[string]string),
|
|
}
|
|
// Regex for cleaning text
|
|
urlRegex = regexp.MustCompile(`https?://[^\s]+`)
|
|
mentionRegex = regexp.MustCompile(`<[@#&!][^>]+>`)
|
|
emojiRegex = regexp.MustCompile(`<a?:[^:]+:\d+>`)
|
|
)
|
|
|
|
// MarkovCommand handles the Markov slash command. It builds (or reuses) a
// word-level Markov chain from recent non-bot messages in the invoking
// channel and returns a generated sentence.
//
// The optional "messages" option sets how many messages to sample. Values
// <= 0 keep lib.AppConfig.MarkovDefaultMessages; values above
// lib.AppConfig.MarkovMaxMessages are clamped to that maximum.
//
// Chains are cached per channel and message count. NOTE(review): a cached
// chain is reused without re-fetching, so messages posted after it was built
// are not reflected until that cache entry is evicted.
//
// It returns an error only if fetching the channel history fails. If no
// text can be generated, it returns a fallback explanation message.
func MarkovCommand(s *discordgo.Session, i *discordgo.InteractionCreate) (string, error) {
	channelID := i.ChannelID
	numMessages := lib.AppConfig.MarkovDefaultMessages // Default value from config

	// Look the option up by name rather than by position, so the command
	// keeps working if other options are added or their order changes.
	for _, opt := range i.ApplicationCommandData().Options {
		if opt.Name != "messages" {
			continue
		}
		requested := int(opt.IntValue())
		switch {
		case requested <= 0:
			// Non-positive values keep the configured default.
		case requested > lib.AppConfig.MarkovMaxMessages:
			numMessages = lib.AppConfig.MarkovMaxMessages // Limit from config
		default:
			numMessages = requested
		}
		break // option names are unique; stop at the first match
	}

	// Reuse a cached chain for this channel/message count if it yields text.
	cacheKey := fmt.Sprintf("%s:%d", channelID, numMessages)
	if chain := getCachedChain(cacheKey); chain != nil {
		if newMessage := generateMessage(chain); newMessage != "" {
			return newMessage, nil
		}
	}

	allMessages, err := fetchMessages(s, channelID, numMessages)
	if err != nil {
		return "", err
	}

	chain := buildMarkovChain(allMessages)
	setCachedChain(cacheKey, chain, allMessages)

	if newMessage := generateMessage(chain); newMessage != "" {
		return newMessage, nil
	}
	// Nothing usable in the sampled history.
	return "I couldn't generate a message. The channel might be empty or contain no usable text.", nil
}
|
|
|
|
func getCachedChain(cacheKey string) map[string][]string {
|
|
markovCache.mu.RLock()
|
|
defer markovCache.mu.RUnlock()
|
|
|
|
if chain, exists := markovCache.chains[cacheKey]; exists {
|
|
return chain
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func setCachedChain(cacheKey string, chain map[string][]string, messages []*discordgo.Message) {
|
|
// Create a hash of the messages to detect changes
|
|
hash := hashMessages(messages)
|
|
|
|
markovCache.mu.Lock()
|
|
defer markovCache.mu.Unlock()
|
|
|
|
// Only cache if we have a meaningful chain
|
|
if len(chain) > 10 {
|
|
markovCache.chains[cacheKey] = chain
|
|
markovCache.hashes[cacheKey] = hash
|
|
|
|
// Simple cache cleanup - keep only last N entries from config
|
|
if len(markovCache.chains) > lib.AppConfig.MarkovCacheSize {
|
|
// Remove oldest entry (simple FIFO)
|
|
for k := range markovCache.chains {
|
|
delete(markovCache.chains, k)
|
|
delete(markovCache.hashes, k)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func hashMessages(messages []*discordgo.Message) string {
|
|
var content strings.Builder
|
|
for _, msg := range messages {
|
|
content.WriteString(msg.ID)
|
|
content.WriteString(msg.Content)
|
|
}
|
|
return fmt.Sprintf("%x", md5.Sum([]byte(content.String())))
|
|
}
|
|
|
|
// fetchMessages pages backwards through a channel's history and returns up
// to numMessages messages that were not written by bots and whose content is
// not blank.
//
// A history request returns at most 100 messages, so the channel is read in
// batches. Each batch starts before the last message of the previous batch.
// The code assumes the last element of a batch is its oldest message, which
// is used as the "before" cursor. Fetching stops when enough usable messages
// have been collected, when the history is exhausted, or when a request
// fails. A failed request returns a wrapped error and no messages.
func fetchMessages(s *discordgo.Session, channelID string, numMessages int) ([]*discordgo.Message, error) {
	const maxBatch = 100 // per-request limit for message history

	var allMessages []*discordgo.Message
	var lastMessageID string

	for len(allMessages) < numMessages {
		batchSize := numMessages - len(allMessages)
		if batchSize > maxBatch {
			batchSize = maxBatch
		}

		batch, err := s.ChannelMessages(channelID, batchSize, lastMessageID, "", "")
		if err != nil {
			return nil, fmt.Errorf("fetching messages from channel %s: %w", channelID, err)
		}
		if len(batch) == 0 {
			break // No more messages to fetch
		}

		// Keep only human-authored, non-blank messages. Guard against a nil
		// Author rather than panicking.
		for _, msg := range batch {
			if msg.Author == nil || msg.Author.Bot {
				continue
			}
			if strings.TrimSpace(msg.Content) == "" {
				continue
			}
			allMessages = append(allMessages, msg)
		}

		lastMessageID = batch[len(batch)-1].ID

		// A short batch means the start of the history was reached. Compare
		// with the size requested, not the 100 cap. Filtering can make later
		// requests smaller than 100, and a full answer to such a request
		// does not mean the history is exhausted.
		if len(batch) < batchSize {
			break
		}
	}

	return allMessages, nil
}
|
|
|
|
// cleanText strips links, angle-bracket mention tokens and custom emoji
// tokens from text, in that order. It then collapses every run of
// whitespace into a single space, which also trims both ends.
func cleanText(text string) string {
	for _, pattern := range []*regexp.Regexp{urlRegex, mentionRegex, emojiRegex} {
		text = pattern.ReplaceAllString(text, "")
	}
	// Joining the whitespace-separated fields leaves no leading or trailing
	// whitespace, so no separate trim is needed.
	return strings.Join(strings.Fields(text), " ")
}
|
|
|
|
// buildMarkovChain maps each lower-cased word to the list of words that
// directly follow it across all messages. Follower words keep their
// original casing, and duplicates are kept, so frequent successors appear
// more often in the list.
//
// Each message is first passed through cleanText. A message is ignored if
// its cleaned text is shorter than 3 bytes or has fewer than 2 words. A
// word is not used as a key if it is shorter than 2 bytes or contains any
// of !@#$%^&*()[]{}, though it may still appear as a follower.
func buildMarkovChain(messages []*discordgo.Message) map[string][]string {
	chain := make(map[string][]string)

	for _, msg := range messages {
		text := cleanText(msg.Content)
		if len(text) < 3 {
			continue // too short to contribute anything
		}
		words := strings.Fields(text)
		if len(words) < 2 {
			continue // a link needs a word and its successor
		}

		// words[idx] is the key and follower is words[idx+1].
		for idx, follower := range words[1:] {
			key := strings.ToLower(words[idx])
			if len(key) < 2 || strings.ContainsAny(key, "!@#$%^&*()[]{}") {
				continue
			}
			chain[key] = append(chain[key], follower)
		}
	}

	return chain
}
|
|
|
|
// generateMessage produces a random sentence by walking the Markov chain.
//
// It returns "" for an empty chain. Otherwise it works as follows:
//   - It picks a start key, preferring one longer than 2 bytes with at least
//     two recorded followers. It gives up on that preference after about 50
//     misses and falls back to the first key visited.
//   - It emits 5 to 24 words. Each next word is drawn uniformly from the
//     current word's follower list, looked up lower-cased.
//   - When the current word has no followers, the walk restarts at a key
//     longer than 2 bytes that has followers, or stops if there is none.
//
// The first word gets its leading letter title-cased. If the text does not
// already end in '.', '!' or '?', one of them is appended at random.
func generateMessage(chain map[string][]string) string {
	if len(chain) == 0 {
		return ""
	}

	// Pick a start word. Map iteration order is unspecified (the runtime
	// varies it), which is what gives the starting point variety.
	var currentWord, fallback string
	attempts := 0
	for word, nextWords := range chain {
		if len(nextWords) >= 2 && len(word) > 2 { // Prefer words with multiple options
			currentWord = word
			break
		}
		if fallback == "" {
			fallback = word
		}
		attempts++
		if attempts > 50 { // bound the search on large chains
			break
		}
	}
	if currentWord == "" {
		// Small chains may contain no preferred word at all; start from any
		// key instead of giving up.
		currentWord = fallback
	}
	if currentWord == "" {
		return ""
	}

	maxWords := 5 + rand.Intn(20) // 5..24 words
	words := make([]string, 0, maxWords)
	for i := 0; i < maxWords; i++ {
		word := currentWord
		if i == 0 {
			// Title-case the first letter only, skipping leading punctuation.
			// The deprecated strings.Title would also capitalize after
			// apostrophes ("don't" -> "Don'T").
			for idx, r := range word {
				if unicode.IsLetter(r) {
					word = word[:idx] + string(unicode.ToTitle(r)) + word[idx+utf8.RuneLen(r):]
					break
				}
				if unicode.IsDigit(r) {
					break // e.g. "3d" stays "3d"
				}
			}
		}
		words = append(words, word)

		// Chain keys are lower-case; followers keep their original case.
		if nextWords := chain[strings.ToLower(currentWord)]; len(nextWords) > 0 {
			currentWord = nextWords[rand.Intn(len(nextWords))]
			continue
		}

		// Dead end: restart from another key that has followers.
		found := false
		for w, nextWords := range chain {
			if len(nextWords) > 0 && len(w) > 2 {
				currentWord = w
				found = true
				break
			}
		}
		if !found {
			break
		}
	}

	result := strings.Join(words, " ")

	// Add random terminal punctuation if missing.
	if len(result) > 0 && !strings.ContainsAny(result[len(result)-1:], ".!?") {
		punctuation := []string{".", "!", "?"}
		result += punctuation[rand.Intn(len(punctuation))]
	}

	return result
}
|
|
|
|
// init seeds math/rand's global source from the current time so generated
// messages vary between runs.
//
// NOTE(review): rand.Seed is deprecated since Go 1.20 (the global source is
// seeded automatically) and is a no-op by default since Go 1.24. If the
// module targets Go 1.20+, this function and the "time" import can be
// removed.
func init() {
	seed := time.Now().UnixNano()
	rand.Seed(seed)
}
|