Files
himbot/command/markov.go
Atridad Lahiji e7adc80bc4
All checks were successful
Docker Deploy / build-and-push (push) Successful in 3m16s
Updated deps + cleaned up mchain
2025-12-22 23:36:58 -07:00

952 lines
23 KiB
Go

package command
import (
"crypto/md5"
"fmt"
"himbot/lib"
"math/rand"
"regexp"
"strings"
"sync"
"time"
"github.com/bwmarrin/discordgo"
)
// MarkovData holds the Markov chains built from a set of messages, one chain
// per n-gram size.
type MarkovData struct {
// Chains maps n-gram size -> prefix -> list of suffixes.
// As built by buildMarkovChain, a prefix is the lowercased, space-joined run
// of n consecutive words and each suffix is the word that followed that run;
// a suffix is appended once per occurrence, so repeats are kept.
Chains map[int]map[string][]string
}
// MarkovCache caches built chains so they do not have to be rebuilt on every
// command invocation.
type MarkovCache struct {
// data maps a cache key to its built chain. The commands in this file use
// keys of the form "<channelID>:<numMessages>".
data map[string]*MarkovData
// hashes maps the same cache key to the hex MD5 digest of the messages the
// chain was built from. NOTE(review): in the visible code these digests are
// only written and deleted, never compared.
hashes map[string]string
// mu is taken (read or write) around every access to data and hashes.
mu sync.RWMutex
}
var (
// markovCache is the package-level chain cache used by getCachedChain and
// setCachedChain.
markovCache = &MarkovCache{
data: make(map[string]*MarkovData),
hashes: make(map[string]string),
}
// Patterns that cleanText strips from text before it is tokenised.
// urlRegex matches an http:// or https:// URL up to the next whitespace.
urlRegex = regexp.MustCompile(`https?://[^\s]+`)
// mentionRegex matches an angle-bracket token whose first character is one
// of @, #, & or ! (presumably Discord user/channel/role mentions).
mentionRegex = regexp.MustCompile(`<[@#&!][^>]+>`)
// emojiRegex matches tokens of the form <:name:digits> or <a:name:digits>
// (presumably Discord custom and animated emoji).
emojiRegex = regexp.MustCompile(`<a?:[^:]+:\d+>`)
)
// MarkovCommand generates a Markov-chain message from the recent history of
// the channel the interaction was issued in.
//
// The number of messages sampled defaults to AppConfig.MarkovDefaultMessages.
// If the first command option is named "messages" its value is used instead:
// non-positive values fall back to the default and values above
// AppConfig.MarkovMaxMessages are clamped to that maximum.
//
// A cached chain for the same channel and message count is tried first. If
// there is none, or it yields an empty message, the channel is fetched and the
// chain rebuilt. The returned error is whatever fetchMessages reported.
func MarkovCommand(s *discordgo.Session, i *discordgo.InteractionCreate) (string, error) {
	limit := lib.AppConfig.MarkovDefaultMessages
	if opts := i.ApplicationCommandData().Options; len(opts) > 0 && opts[0].Name == "messages" {
		requested := int(opts[0].IntValue())
		switch {
		case requested <= 0:
			limit = lib.AppConfig.MarkovDefaultMessages
		case requested > lib.AppConfig.MarkovMaxMessages:
			limit = lib.AppConfig.MarkovMaxMessages
		default:
			limit = requested
		}
	}

	key := fmt.Sprintf("%s:%d", i.ChannelID, limit)

	// Fast path: reuse a cached chain when it can still produce output.
	if cached := getCachedChain(key); cached != nil {
		if msg := generateAdvancedMessage(cached); msg != "" {
			return msg, nil
		}
	}

	// Slow path: fetch history, build the chain and offer it to the cache.
	history, err := fetchMessages(s, i.ChannelID, limit)
	if err != nil {
		return "", err
	}
	chain := buildMarkovChain(history)
	setCachedChain(key, chain, history)

	if msg := generateAdvancedMessage(chain); msg != "" {
		return msg, nil
	}
	return "I couldn't generate a message. The channel might be empty or contain no usable text.", nil
}
// getCachedChain returns the chain stored under cacheKey, or nil when the key
// is not in the cache. The lookup is done under the cache's read lock.
func getCachedChain(cacheKey string) *MarkovData {
	markovCache.mu.RLock()
	// Indexing a missing key yields the zero value, i.e. a nil *MarkovData.
	cached := markovCache.data[cacheKey]
	markovCache.mu.RUnlock()
	return cached
}
// setCachedChain stores data under cacheKey together with a digest of the
// messages it was built from, then trims the cache back to
// AppConfig.MarkovCacheSize entries.
//
// Chains with 10 or fewer 1-gram prefixes are not cached, and nothing is
// cached when the configured cache size is not positive.
//
// Eviction removes arbitrary other entries: Go's map iteration order is
// unspecified, so this is not FIFO. The entry stored by this call is never the
// one evicted.
func setCachedChain(cacheKey string, data *MarkovData, messages []*discordgo.Message) {
	limit := lib.AppConfig.MarkovCacheSize
	if data == nil || len(data.Chains[1]) <= 10 || limit <= 0 {
		return
	}

	// Hash outside the lock, and only once we know the entry will be stored.
	hash := hashMessages(messages)

	markovCache.mu.Lock()
	defer markovCache.mu.Unlock()

	markovCache.data[cacheKey] = data
	markovCache.hashes[cacheKey] = hash

	// Deleting during range is safe in Go. Because limit >= 1 and the new key
	// is skipped, the loop always reaches len <= limit.
	for k := range markovCache.data {
		if len(markovCache.data) <= limit {
			break
		}
		if k == cacheKey {
			continue
		}
		delete(markovCache.data, k)
		delete(markovCache.hashes, k)
	}
}
// hashMessages returns the lowercase hex MD5 digest of every message's ID
// followed by its content, concatenated in slice order.
func hashMessages(messages []*discordgo.Message) string {
	digest := md5.New()
	for _, m := range messages {
		// Feeding the pieces incrementally hashes the same byte stream as
		// hashing their concatenation.
		digest.Write([]byte(m.ID))
		digest.Write([]byte(m.Content))
	}
	return fmt.Sprintf("%x", digest.Sum(nil))
}
// fetchMessages pages backwards through a channel's history and returns up to
// numMessages usable messages, in the order the API returned them.
//
// A message is usable when it has an author who is not a bot and its content
// is not blank. Paging stops when enough messages are collected, when a page
// comes back empty, or when a page holds fewer than 100 messages. Because
// filtering happens after each request, fewer than numMessages may be returned.
//
// A non-positive numMessages yields (nil, nil). Any error from
// Session.ChannelMessages is returned unchanged with a nil slice.
func fetchMessages(s *discordgo.Session, channelID string, numMessages int) ([]*discordgo.Message, error) {
	const pageSize = 100

	// make() panics on a negative capacity, so reject bad counts up front.
	if numMessages <= 0 {
		return nil, nil
	}

	collected := make([]*discordgo.Message, 0, numMessages)
	beforeID := ""

	for len(collected) < numMessages {
		batchSize := numMessages - len(collected)
		if batchSize > pageSize {
			batchSize = pageSize
		}

		batch, err := s.ChannelMessages(channelID, batchSize, beforeID, "", "")
		if err != nil {
			return nil, err
		}
		if len(batch) == 0 {
			break
		}

		for _, msg := range batch {
			// Author is a pointer; skip messages without one instead of
			// dereferencing nil.
			if msg.Author == nil || msg.Author.Bot {
				continue
			}
			if strings.TrimSpace(msg.Content) == "" {
				continue
			}
			collected = append(collected, msg)
		}

		// Continue from the last message of this page on the next request.
		beforeID = batch[len(batch)-1].ID

		// A short page ends the scan.
		if len(batch) < pageSize {
			break
		}
	}

	return collected, nil
}
// cleanText removes everything matched by urlRegex, mentionRegex and
// emojiRegex (in that order), then collapses every run of whitespace into a
// single space and trims the ends.
func cleanText(text string) string {
	for _, pattern := range []*regexp.Regexp{urlRegex, mentionRegex, emojiRegex} {
		text = pattern.ReplaceAllString(text, "")
	}
	collapsed := strings.Join(strings.Fields(text), " ")
	return strings.TrimSpace(collapsed)
}
// buildMarkovChain builds n-gram Markov chains from the given messages.
//
// Each message is cleaned with cleanText and split on whitespace exactly once.
// Messages whose cleaned text is shorter than 3 bytes are ignored, and
// messages with fewer than 2 words contribute no transitions.
//
// The largest n-gram size starts at AppConfig.MarkovMaxNGram and is lowered
// (never below 2) while estimateMemoryUsage for the total word count exceeds
// AppConfig.MarkovMemoryLimit.
//
// For every size n from 1 to that maximum, each run of n consecutive words is
// recorded as a lowercased, space-joined prefix mapped to the word that
// follows it. A run is skipped if any of its words is shorter than 2 bytes or
// contains one of "!@#$%^&*()[]{}". The following word is stored as written
// and is not subject to that filter.
func buildMarkovChain(messages []*discordgo.Message) *MarkovData {
	data := &MarkovData{
		Chains: make(map[int]map[string][]string),
	}

	// Tokenise once; the word lists feed both the size estimate and the build.
	tokenized := make([][]string, 0, len(messages))
	totalWords := 0
	for _, msg := range messages {
		cleaned := cleanText(msg.Content)
		if len(cleaned) < 3 {
			continue
		}
		words := strings.Fields(cleaned)
		totalWords += len(words)
		tokenized = append(tokenized, words)
	}

	// Shrink the maximum n-gram size until the estimate fits the limit.
	maxNGram := lib.AppConfig.MarkovMaxNGram
	for maxNGram > 2 && estimateMemoryUsage(totalWords, maxNGram) > lib.AppConfig.MarkovMemoryLimit {
		maxNGram--
	}

	for n := 1; n <= maxNGram; n++ {
		data.Chains[n] = make(map[string][]string)
	}

	for _, words := range tokenized {
		if len(words) < 2 {
			continue
		}

		// Per-word work is independent of n, so do it once per message.
		lowered := make([]string, len(words))
		usable := make([]bool, len(words))
		for k, w := range words {
			lowered[k] = strings.ToLower(w)
			usable[k] = len(w) >= 2 && !strings.ContainsAny(w, "!@#$%^&*()[]{}")
		}

		// A prefix of n words needs at least one following word.
		for n := 1; n <= maxNGram && n < len(words); n++ {
			chain := data.Chains[n]
			for i := 0; i+n < len(words); i++ {
				valid := true
				for _, ok := range usable[i : i+n] {
					if !ok {
						valid = false
						break
					}
				}
				if !valid {
					continue
				}

				prefix := strings.Join(lowered[i:i+n], " ")
				chain[prefix] = append(chain[prefix], words[i+n])
			}
		}
	}

	return data
}
// estimateMemoryUsage returns a rough memory estimate, in MB, for chains built
// from wordCount words up to the given n-gram size.
//
// The base figure is wordCount/2000 (integer division). For maxNGram 2 to 6 it
// is multiplied by 3, 8, 15, 25 and 40 respectively; any other maxNGram
// returns the base figure unchanged.
func estimateMemoryUsage(wordCount int, maxNGram int) int {
	// factors[k] is the multiplier for n-gram size k+2.
	factors := [...]int{3, 8, 15, 25, 40}

	baseMB := wordCount / 2000
	if idx := maxNGram - 2; idx >= 0 && idx < len(factors) {
		return baseMB * factors[idx]
	}
	return baseMB
}
// init seeds the shared math/rand source from the current wall-clock time, in
// nanoseconds, when the package is loaded.
func init() {
	seed := time.Now().UnixNano()
	rand.Seed(seed)
}
// MarkovQuestionCommand answers a question using a Markov chain built from the
// history of the channel the interaction was issued in.
//
// It reads two command options by name: "question" (the text to answer) and
// "messages" (how many messages to sample; non-positive values fall back to
// AppConfig.MarkovDefaultMessages and values above AppConfig.MarkovMaxMessages
// are clamped). Without a question it returns a prompt asking for one.
//
// A cached chain for the same channel and message count is reused when
// present; otherwise the channel is fetched, the chain built and offered to
// the cache. The reply is formatted as a bold Q/A pair. The returned error is
// whatever fetchMessages reported.
func MarkovQuestionCommand(s *discordgo.Session, i *discordgo.InteractionCreate) (string, error) {
	question := ""
	limit := lib.AppConfig.MarkovDefaultMessages

	for _, opt := range i.ApplicationCommandData().Options {
		if opt.Name == "question" {
			question = opt.StringValue()
			continue
		}
		if opt.Name != "messages" {
			continue
		}
		limit = int(opt.IntValue())
		if limit <= 0 {
			limit = lib.AppConfig.MarkovDefaultMessages
		} else if limit > lib.AppConfig.MarkovMaxMessages {
			limit = lib.AppConfig.MarkovMaxMessages
		}
	}

	if question == "" {
		return "Please provide a question!", nil
	}

	key := fmt.Sprintf("%s:%d", i.ChannelID, limit)

	chain := getCachedChain(key)
	if chain == nil {
		history, err := fetchMessages(s, i.ChannelID, limit)
		if err != nil {
			return "", err
		}
		chain = buildMarkovChain(history)
		setCachedChain(key, chain, history)
	}

	answer := generateAdvancedQuestionAnswer(chain, question)
	if answer == "" {
		answer = "I couldn't generate an answer to that question. The channel might not have enough relevant content."
	}

	return fmt.Sprintf("**Q:** %s\n**A:** %s", question, answer), nil
}
// generateQuestionAnswer is a thin alias for generateAdvancedQuestionAnswer;
// it forwards both arguments and returns the result unchanged.
func generateQuestionAnswer(data *MarkovData, question string) string {
	answer := generateAdvancedQuestionAnswer(data, question)
	return answer
}
// categorizeQuestion classifies a question by the keywords it contains.
//
// The question is lowercased and split into whole words; ASCII punctuation and
// whitespace separate words, so "what's" yields "what" and "s". Matching whole
// words rather than substrings keeps "this" from counting as "is" or "show"
// from counting as "how".
//
// The result is the first of "what", "how", "why", "when", "where", "who",
// "which" that appears as a word, checked in that order. Failing that it is
// "yesno" if any of "is", "are", "do", "does" appears, and "general" otherwise.
func categorizeQuestion(question string) string {
	words := strings.FieldsFunc(strings.ToLower(question), func(r rune) bool {
		// Keep ASCII letters and digits, and all non-ASCII runes, inside words.
		isWordRune := (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r >= 0x80
		return !isWordRune
	})

	seen := make(map[string]bool, len(words))
	for _, w := range words {
		seen[w] = true
	}

	// Order matters: it is the priority when several keywords are present.
	for _, kind := range []string{"what", "how", "why", "when", "where", "who", "which"} {
		if seen[kind] {
			return kind
		}
	}
	for _, aux := range []string{"is", "are", "do", "does"} {
		if seen[aux] {
			return "yesno"
		}
	}
	return "general"
}
// WordCandidate pairs a possible starting word with the score used to rank it.
type WordCandidate struct {
// Word is the candidate word; findBestStartingWords only emits words that
// are keys of the 1-gram chain.
Word string
// Score is the ranking weight; findBestStartingWords orders candidates by
// descending Score.
Score int
}
// findBestStartingWords ranks words that could start an answer and returns at
// most the 10 best, highest score first.
//
// Only words that are 1-gram prefixes with at least one follower are scored:
//   - each question word longer than 2 bytes that is not a stop word gets +10
//     per occurrence in questionWords;
//   - each contextual word for questionType (see getContextualWords) gets +5;
//   - any other chain word with at least 3 followers, longer than 2 bytes and
//     not a stop word, gets half its follower count.
//
// Ties are broken alphabetically so the result does not depend on map
// iteration order. It returns nil when nothing scored.
func findBestStartingWords(data *MarkovData, questionWords []string, questionType string) []WordCandidate {
	const maxCandidates = 10

	chain := data.Chains[1]
	scores := make(map[string]int)

	for _, word := range questionWords {
		if len(word) > 2 && !isStopWord(word) && len(chain[word]) > 0 {
			scores[word] += 10
		}
	}

	for _, word := range getContextualWords(questionType) {
		if len(chain[word]) > 0 {
			scores[word] += 5
		}
	}

	// Fallback pool: frequent chain words that were not scored above.
	for word, followers := range chain {
		if _, scored := scores[word]; scored {
			continue
		}
		if len(followers) >= 3 && len(word) > 2 && !isStopWord(word) {
			scores[word] = len(followers) / 2
		}
	}

	if len(scores) == 0 {
		return nil
	}

	ranked := make([]WordCandidate, 0, len(scores))
	for word, score := range scores {
		ranked = append(ranked, WordCandidate{Word: word, Score: score})
	}

	limit := len(ranked)
	if limit > maxCandidates {
		limit = maxCandidates
	}

	// Partial selection: only the first `limit` slots are needed, so fill just
	// those instead of ordering the whole vocabulary.
	for i := 0; i < limit; i++ {
		best := i
		for j := i + 1; j < len(ranked); j++ {
			a, b := ranked[j], ranked[best]
			if a.Score > b.Score || (a.Score == b.Score && a.Word < b.Word) {
				best = j
			}
		}
		ranked[i], ranked[best] = ranked[best], ranked[i]
	}

	return ranked[:limit]
}
// getContextualWords returns the fixed list of seed words associated with a
// question type ("what", "how", "why", "when", "where", "who", "which" or
// "yesno"). Any other type gets a generic list of opinion verbs. A fresh slice
// is returned on every call.
func getContextualWords(questionType string) []string {
	byType := map[string][]string{
		"what":  {"thing", "something", "object", "idea", "concept", "stuff", "item"},
		"how":   {"way", "method", "process", "steps", "technique", "approach"},
		"why":   {"because", "reason", "cause", "since", "due", "explanation"},
		"when":  {"time", "moment", "day", "hour", "yesterday", "today", "tomorrow", "now", "then"},
		"where": {"place", "location", "here", "there", "somewhere", "anywhere"},
		"who":   {"person", "people", "someone", "anyone", "everybody", "nobody"},
		"which": {"choice", "option", "selection", "pick", "prefer"},
		"yesno": {"yes", "no", "maybe", "definitely", "probably", "possibly", "sure", "absolutely"},
	}

	if words, known := byType[questionType]; known {
		return words
	}
	return []string{"think", "believe", "know", "understand", "feel", "seem"}
}
// scoreResponse rates a generated response; higher is better.
//
// Scoring: +10 for 8-16 words, otherwise +5 for 6-20 words; +5 when the mean
// word length in bytes is strictly between 3.5 and 6.0; plus one point for
// each word longer than 3 bytes that is not a stop word. questionType is
// accepted but not used.
func scoreResponse(response string, questionType string) int {
	words := strings.Fields(response)
	count := len(words)

	score := 0
	switch {
	case count >= 8 && count <= 16:
		score += 10
	case count >= 6 && count <= 20:
		score += 5
	}

	// One pass gathers both the total length and the content-word count.
	totalLength, contentWords := 0, 0
	for _, w := range words {
		totalLength += len(w)
		if len(w) > 3 && !isStopWord(strings.ToLower(w)) {
			contentWords++
		}
	}

	if count > 0 {
		avg := float64(totalLength) / float64(count)
		if avg > 3.5 && avg < 6.0 {
			score += 5
		}
	}

	return score + contentWords
}
// getPunctuationForQuestionType returns the pool of closing punctuation marks
// for a question type. Marks are repeated in the pool to weight a uniform
// random pick towards them.
func getPunctuationForQuestionType(questionType string) []string {
	if questionType == "yesno" {
		return []string{".", "!", "."}
	}
	if questionType == "why" || questionType == "how" {
		return []string{".", ".", "!"}
	}
	return []string{".", ".", "!", "."}
}
// min returns the smaller of two ints; when they are equal it returns that
// shared value.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
// isStopWord reports whether word is one of a fixed set of common English
// function words. The comparison is exact and case-sensitive, so callers
// lowercase the word first when they want a case-insensitive check.
//
// A switch is used instead of a map literal so that no map is allocated and
// filled on every call; this function runs inside per-word loops.
func isStopWord(word string) bool {
	switch word {
	case "a", "an", "and", "are", "as", "at", "be", "by",
		"for", "from", "has", "he", "in", "is", "it",
		"its", "of", "on", "that", "the", "to", "was",
		"will", "with", "or", "but", "if", "so", "do":
		return true
	}
	return false
}
// generateAdvancedMessage produces a message from the chains in data, or ""
// when the 1-gram chain is empty.
//
// It makes 5 attempts and keeps the one scoreAdvancedMessage rates highest
// (an attempt must score above 0 to be kept). Each attempt:
//   - takes a start word from the 1-gram chain in map-iteration order: the
//     first word with at least 2 followers that is longer than 2 bytes and not
//     a stop word, or simply the 51st word visited if none qualified earlier;
//   - emits between 10 and 17 words, title-casing the first one;
//   - chooses each next word from the longest matching n-gram prefix (n from 5
//     down to 2) over the recent history, then from the current word's 1-gram
//     followers, and finally by restarting from any suitable chain word.
//
// History is capped at the last 10 words. If the winning message does not end
// in '.', '!' or '?', one of ".", "!", "?", "." is appended at random.
func generateAdvancedMessage(data *MarkovData) string {
	unigrams := data.Chains[1]
	if len(unigrams) == 0 {
		return ""
	}

	// chooseStart returns a start word, or "" if the chain holds no candidate.
	chooseStart := func() string {
		skipped := 0
		for word, followers := range unigrams {
			if len(followers) >= 2 && len(word) > 2 && !isStopWord(word) {
				return word
			}
			skipped++
			if skipped > 50 {
				// Give up filtering and accept whatever word this is.
				return word
			}
		}
		return ""
	}

	// contextualNext tries the longest n-gram context first and returns "" if
	// no context of size 5..2 yields a word.
	contextualNext := func(history []string) string {
		for n := 5; n >= 2; n-- {
			table := data.Chains[n]
			if len(history) < n || table == nil {
				continue
			}
			lowered := make([]string, n)
			for j, w := range history[len(history)-n:] {
				lowered[j] = strings.ToLower(w)
			}
			options := table[strings.Join(lowered, " ")]
			if len(options) == 0 {
				continue
			}
			if picked := selectBestNextWord(options, history); picked != "" {
				return picked
			}
		}
		return ""
	}

	// restartWord returns any chain word fit to continue from, or "" if none.
	restartWord := func() string {
		for word, followers := range unigrams {
			if len(followers) > 0 && len(word) > 2 && !isStopWord(word) {
				return word
			}
		}
		return ""
	}

	best, bestScore := "", 0
	for attempt := 0; attempt < 5; attempt++ {
		current := chooseStart()
		if current == "" {
			continue
		}

		budget := 10 + rand.Intn(8)
		history := []string{current}
		words := make([]string, 0, budget)

		for i := 0; i < budget; i++ {
			if i == 0 {
				words = append(words, strings.Title(current))
			} else {
				words = append(words, current)
			}

			next := contextualNext(history)
			if next == "" {
				if followers := unigrams[strings.ToLower(current)]; len(followers) > 0 {
					next = selectBestNextWord(followers, history)
				}
			}
			if next == "" {
				if next = restartWord(); next == "" {
					break
				}
			}

			current = next
			history = append(history, current)
			if len(history) > 10 {
				history = history[1:]
			}
		}

		candidate := strings.Join(words, " ")
		if s := scoreAdvancedMessage(candidate); s > bestScore {
			best, bestScore = candidate, s
		}
	}

	if len(best) > 0 {
		switch best[len(best)-1] {
		case '.', '!', '?':
			// Already terminated.
		default:
			marks := []string{".", "!", "?", "."}
			best += marks[rand.Intn(len(marks))]
		}
	}

	return best
}
// generateAdvancedQuestionAnswer builds an answer to question from the chains
// in data.
//
// The question is cleaned with cleanText, classified with categorizeQuestion
// and lowercased into words; findBestStartingWords then proposes start words
// and generateAdvancedCoherentResponse writes the answer. It returns "" when
// the 1-gram chain is empty or no start word is found.
func generateAdvancedQuestionAnswer(data *MarkovData, question string) string {
	if len(data.Chains[1]) == 0 {
		return ""
	}

	cleaned := cleanText(question)
	kind := categorizeQuestion(cleaned)
	tokens := strings.Fields(strings.ToLower(cleaned))

	starts := findBestStartingWords(data, tokens, kind)
	if len(starts) == 0 {
		return ""
	}

	return generateAdvancedCoherentResponse(data, starts, kind)
}
// generateAdvancedCoherentResponse writes an answer that starts from one of
// the given candidates, or returns "" when there are none.
//
// It makes 5 attempts and keeps the one scoreAdvancedResponse rates highest
// (an attempt must score above 0 to be kept). Each attempt:
//   - starts from the top candidate, except that with more than one candidate
//     there is a roughly 30% chance of picking uniformly among the first 3;
//   - emits between 12 and 21 words, title-casing the first one;
//   - chooses each next word from the longest matching n-gram prefix (n from 5
//     down to 2) over the recent history, then from the current word's 1-gram
//     followers, and finally by restarting from the first candidate that has
//     followers.
//
// History is capped at the last 10 words. If the winning response does not end
// in '.', '!' or '?', a mark from getPunctuationForQuestionType is appended at
// random.
func generateAdvancedCoherentResponse(data *MarkovData, candidates []WordCandidate, questionType string) string {
	if len(candidates) == 0 {
		return ""
	}

	unigrams := data.Chains[1]

	// contextualNext tries the longest n-gram context first and returns "" if
	// no context of size 5..2 yields a word.
	contextualNext := func(history []string) string {
		for n := 5; n >= 2; n-- {
			table := data.Chains[n]
			if len(history) < n || table == nil {
				continue
			}
			lowered := make([]string, n)
			for j, w := range history[len(history)-n:] {
				lowered[j] = strings.ToLower(w)
			}
			options := table[strings.Join(lowered, " ")]
			if len(options) == 0 {
				continue
			}
			if picked := selectBestNextWord(options, history); picked != "" {
				return picked
			}
		}
		return ""
	}

	// restart returns the first candidate that has followers in the 1-gram
	// chain; the bool is false when no candidate does.
	restart := func() (string, bool) {
		for _, c := range candidates {
			if len(unigrams[c.Word]) > 0 {
				return c.Word, true
			}
		}
		return "", false
	}

	best, bestScore := "", 0
	for attempt := 0; attempt < 5; attempt++ {
		pick := 0
		if len(candidates) > 1 && rand.Float32() > 0.7 {
			pick = rand.Intn(min(3, len(candidates)))
		}

		current := candidates[pick].Word
		history := []string{current}

		budget := 12 + rand.Intn(10)
		words := make([]string, 0, budget)

		for i := 0; i < budget; i++ {
			if i == 0 {
				words = append(words, strings.Title(current))
			} else {
				words = append(words, current)
			}

			next := contextualNext(history)
			if next == "" {
				if followers := unigrams[strings.ToLower(current)]; len(followers) > 0 {
					next = selectBestNextWord(followers, history)
				}
			}
			if next == "" {
				restarted, ok := restart()
				if !ok {
					break
				}
				next = restarted
			}

			current = next
			history = append(history, current)
			if len(history) > 10 {
				history = history[1:]
			}
		}

		candidate := strings.Join(words, " ")
		if s := scoreAdvancedResponse(candidate, questionType); s > bestScore {
			best, bestScore = candidate, s
		}
	}

	if len(best) > 0 {
		switch best[len(best)-1] {
		case '.', '!', '?':
			// Already terminated.
		default:
			marks := getPunctuationForQuestionType(questionType)
			best += marks[rand.Intn(len(marks))]
		}
	}

	return best
}
// scoreAdvancedMessage rates a generated message; higher is better.
//
// Points are awarded for:
//   - length: +15 for 10-16 words, else +10 for 8-18, else +5 for 6-20;
//   - mean word length in bytes strictly between 3.5 and 6.5: +8;
//   - share of distinct words (case-insensitive): +10 above 0.8, else +5
//     above 0.6;
//   - content: +2 for each word longer than 3 bytes that is not a stop word;
//   - no stutter: +5 when the message contains none of " a a ", " the the "
//     or " you you ".
func scoreAdvancedMessage(message string) int {
	words := strings.Fields(message)
	count := len(words)

	score := 0
	switch {
	case count >= 10 && count <= 16:
		score += 15
	case count >= 8 && count <= 18:
		score += 10
	case count >= 6 && count <= 20:
		score += 5
	}

	// One pass gathers length, distinct-word and content-word statistics.
	totalLength, contentWords := 0, 0
	distinct := make(map[string]struct{}, count)
	for _, w := range words {
		lower := strings.ToLower(w)
		totalLength += len(w)
		distinct[lower] = struct{}{}
		if len(w) > 3 && !isStopWord(lower) {
			contentWords++
		}
	}

	if count > 0 {
		avg := float64(totalLength) / float64(count)
		if avg > 3.5 && avg < 6.5 {
			score += 8
		}

		ratio := float64(len(distinct)) / float64(count)
		switch {
		case ratio > 0.8:
			score += 10
		case ratio > 0.6:
			score += 5
		}
	}

	score += contentWords * 2

	stutters := false
	for _, pair := range []string{" a a ", " the the ", " you you "} {
		if strings.Contains(message, pair) {
			stutters = true
			break
		}
	}
	if !stutters {
		score += 5
	}

	return score
}
// scoreAdvancedResponse rates an answer: the scoreAdvancedMessage score plus a
// single +8 bonus when the lowercased response contains, as a substring, any
// marker associated with the question type. Types "yesno", "why", "how",
// "when" and "where" have markers; other types get no bonus.
func scoreAdvancedResponse(response string, questionType string) int {
	score := scoreAdvancedMessage(response)

	markers := map[string][]string{
		"yesno": {"yes", "no", "maybe", "definitely"},
		"why":   {"because", "reason", "since", "due"},
		"how":   {"way", "method", "process", "steps"},
		"when":  {"time", "day", "hour", "moment"},
		"where": {"place", "location", "here", "there"},
	}

	lowered := strings.ToLower(response)
	for _, marker := range markers[questionType] {
		if strings.Contains(lowered, marker) {
			// The bonus is awarded once, however many markers match.
			return score + 8
		}
	}
	return score
}
// isValidNextWord reports whether nextWord may follow the words in
// wordHistory. All comparisons are case-insensitive.
//
// With an empty history every word is valid. Otherwise a word is rejected if:
//   - it equals the last word of the history (this also covers "a a",
//     "the the" and "you you");
//   - it appears at least twice among the last 3 history words;
//   - it is "a" directly after "you";
//   - it is an article ("a", "an", "the") directly after an article.
func isValidNextWord(wordHistory []string, nextWord string) bool {
	if len(wordHistory) == 0 {
		return true
	}

	next := strings.ToLower(nextWord)
	prev := strings.ToLower(wordHistory[len(wordHistory)-1])

	if prev == next {
		return false
	}

	if len(wordHistory) >= 3 {
		hits := 0
		for _, w := range wordHistory[len(wordHistory)-3:] {
			if strings.ToLower(w) == next {
				hits++
			}
		}
		if hits >= 2 {
			return false
		}
	}

	if prev == "you" && next == "a" {
		return false
	}

	isArticle := func(w string) bool {
		return w == "a" || w == "an" || w == "the"
	}
	if isArticle(prev) && isArticle(next) {
		return false
	}

	return true
}
// selectBestNextWord picks one word at random from options, or returns "" when
// options is empty.
//
// The pool is narrowed in stages:
//  1. keep the options isValidNextWord accepts for wordHistory;
//  2. if none remain, keep the options that differ (case-insensitively) from
//     the last history word, or every option if that leaves nothing either;
//  3. within the pool, prefer words longer than 2 bytes that are not stop
//     words, when any exist.
//
// The final choice is uniform over the resulting pool; duplicates in options
// are kept, so a repeated word is proportionally more likely.
func selectBestNextWord(options []string, wordHistory []string) string {
	if len(options) == 0 {
		return ""
	}

	pool := make([]string, 0, len(options))
	for _, opt := range options {
		if isValidNextWord(wordHistory, opt) {
			pool = append(pool, opt)
		}
	}

	if len(pool) == 0 {
		// Relax the rules: only forbid repeating the previous word outright.
		for _, opt := range options {
			if len(wordHistory) == 0 ||
				strings.ToLower(wordHistory[len(wordHistory)-1]) != strings.ToLower(opt) {
				pool = append(pool, opt)
			}
		}
		if len(pool) == 0 {
			pool = options
		}
	}

	preferred := make([]string, 0, len(pool))
	for _, opt := range pool {
		if len(opt) > 2 && !isStopWord(strings.ToLower(opt)) {
			preferred = append(preferred, opt)
		}
	}
	if len(preferred) > 0 {
		pool = preferred
	}

	return pool[rand.Intn(len(pool))]
}