1472 lines
45 KiB
Go
1472 lines
45 KiB
Go
package command
|
|
|
|
import (
|
|
"crypto/md5"
|
|
"fmt"
|
|
"himbot/lib"
|
|
"math/rand"
|
|
"regexp"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/bwmarrin/discordgo"
|
|
)
|
|
|
|
// MarkovCache stores previously built Markov tables so they can be reused
// instead of being rebuilt. Every map is indexed by a cache key (the commands
// in this file use "<channelID>:<messageCount>"), and mu guards all of them.
type MarkovCache struct {
	// mu must be held while reading or writing any of the maps below.
	mu sync.RWMutex

	// chains holds the 1-gram table per cache key: word -> following words.
	chains map[string]map[string][]string
	// twoGrams through fiveGrams hold the higher-order tables per cache key;
	// each extra nesting level is one more word of context.
	twoGrams   map[string]map[string]map[string][]string
	threeGrams map[string]map[string]map[string]map[string][]string
	fourGrams  map[string]map[string]map[string]map[string]map[string][]string
	fiveGrams  map[string]map[string]map[string]map[string]map[string]map[string][]string

	// hashes holds the hashMessages digest of the messages each entry was
	// built from.
	hashes map[string]string
}
|
|
|
|
var (
|
|
markovCache = &MarkovCache{
|
|
chains: make(map[string]map[string][]string),
|
|
twoGrams: make(map[string]map[string]map[string][]string),
|
|
threeGrams: make(map[string]map[string]map[string]map[string][]string),
|
|
fourGrams: make(map[string]map[string]map[string]map[string]map[string][]string),
|
|
fiveGrams: make(map[string]map[string]map[string]map[string]map[string]map[string][]string),
|
|
hashes: make(map[string]string),
|
|
}
|
|
// Regex for cleaning text
|
|
urlRegex = regexp.MustCompile(`https?://[^\s]+`)
|
|
mentionRegex = regexp.MustCompile(`<[@#&!][^>]+>`)
|
|
emojiRegex = regexp.MustCompile(`<a?:[^:]+:\d+>`)
|
|
)
|
|
|
|
func MarkovCommand(s *discordgo.Session, i *discordgo.InteractionCreate) (string, error) {
|
|
channelID := i.ChannelID
|
|
|
|
numMessages := lib.AppConfig.MarkovDefaultMessages // Default value from config
|
|
if len(i.ApplicationCommandData().Options) > 0 {
|
|
if i.ApplicationCommandData().Options[0].Name == "messages" {
|
|
numMessages = int(i.ApplicationCommandData().Options[0].IntValue())
|
|
if numMessages <= 0 {
|
|
numMessages = lib.AppConfig.MarkovDefaultMessages
|
|
} else if numMessages > lib.AppConfig.MarkovMaxMessages {
|
|
numMessages = lib.AppConfig.MarkovMaxMessages // Limit from config
|
|
}
|
|
}
|
|
}
|
|
|
|
// Check cache first
|
|
cacheKey := fmt.Sprintf("%s:%d", channelID, numMessages)
|
|
if chain := getCachedChain(cacheKey); chain != nil {
|
|
newMessage := generateMessage(chain)
|
|
if newMessage != "" {
|
|
return newMessage, nil
|
|
}
|
|
}
|
|
|
|
// Fetch messages
|
|
allMessages, err := fetchMessages(s, channelID, numMessages)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
// Build the Markov chain from the fetched messages
|
|
chain, twoGramChain, threeGramChain, fourGramChain, fiveGramChain := buildMarkovChain(allMessages)
|
|
|
|
// Cache the chain
|
|
setCachedChain(cacheKey, chain, twoGramChain, threeGramChain, fourGramChain, fiveGramChain, allMessages)
|
|
|
|
// Generate a new message using the improved Markov chain
|
|
newMessage := generateAdvancedMessage(chain, twoGramChain, threeGramChain, fourGramChain, fiveGramChain)
|
|
|
|
// Check if the generated message is empty and provide a fallback message
|
|
if newMessage == "" {
|
|
newMessage = "I couldn't generate a message. The channel might be empty or contain no usable text."
|
|
}
|
|
|
|
return newMessage, nil
|
|
}
|
|
|
|
func getCachedChain(cacheKey string) map[string][]string {
|
|
markovCache.mu.RLock()
|
|
defer markovCache.mu.RUnlock()
|
|
|
|
if chain, exists := markovCache.chains[cacheKey]; exists {
|
|
return chain
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func getCachedTwoGramChain(cacheKey string) map[string]map[string][]string {
|
|
markovCache.mu.RLock()
|
|
defer markovCache.mu.RUnlock()
|
|
|
|
if twoGram, exists := markovCache.twoGrams[cacheKey]; exists {
|
|
return twoGram
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func getCachedThreeGramChain(cacheKey string) map[string]map[string]map[string][]string {
|
|
markovCache.mu.RLock()
|
|
defer markovCache.mu.RUnlock()
|
|
|
|
if threeGram, exists := markovCache.threeGrams[cacheKey]; exists {
|
|
return threeGram
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func getCachedFourGramChain(cacheKey string) map[string]map[string]map[string]map[string][]string {
|
|
markovCache.mu.RLock()
|
|
defer markovCache.mu.RUnlock()
|
|
|
|
if fourGram, exists := markovCache.fourGrams[cacheKey]; exists {
|
|
return fourGram
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func getCachedFiveGramChain(cacheKey string) map[string]map[string]map[string]map[string]map[string][]string {
|
|
markovCache.mu.RLock()
|
|
defer markovCache.mu.RUnlock()
|
|
|
|
if fiveGram, exists := markovCache.fiveGrams[cacheKey]; exists {
|
|
return fiveGram
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func setCachedChain(cacheKey string, chain map[string][]string, twoGramChain map[string]map[string][]string, threeGramChain map[string]map[string]map[string][]string, fourGramChain map[string]map[string]map[string]map[string][]string, fiveGramChain map[string]map[string]map[string]map[string]map[string][]string, messages []*discordgo.Message) {
|
|
hash := hashMessages(messages)
|
|
|
|
markovCache.mu.Lock()
|
|
defer markovCache.mu.Unlock()
|
|
|
|
if len(chain) > 10 {
|
|
markovCache.chains[cacheKey] = chain
|
|
markovCache.twoGrams[cacheKey] = twoGramChain
|
|
markovCache.threeGrams[cacheKey] = threeGramChain
|
|
markovCache.fourGrams[cacheKey] = fourGramChain
|
|
markovCache.fiveGrams[cacheKey] = fiveGramChain
|
|
markovCache.hashes[cacheKey] = hash
|
|
|
|
// Simple FIFO cache cleanup
|
|
if len(markovCache.chains) > lib.AppConfig.MarkovCacheSize {
|
|
for k := range markovCache.chains {
|
|
delete(markovCache.chains, k)
|
|
delete(markovCache.twoGrams, k)
|
|
delete(markovCache.threeGrams, k)
|
|
delete(markovCache.fourGrams, k)
|
|
delete(markovCache.fiveGrams, k)
|
|
delete(markovCache.hashes, k)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func hashMessages(messages []*discordgo.Message) string {
|
|
var content strings.Builder
|
|
for _, msg := range messages {
|
|
content.WriteString(msg.ID)
|
|
content.WriteString(msg.Content)
|
|
}
|
|
return fmt.Sprintf("%x", md5.Sum([]byte(content.String())))
|
|
}
|
|
|
|
func fetchMessages(s *discordgo.Session, channelID string, numMessages int) ([]*discordgo.Message, error) {
|
|
var allMessages []*discordgo.Message
|
|
var lastMessageID string
|
|
|
|
for len(allMessages) < numMessages {
|
|
batchSize := 100
|
|
if numMessages-len(allMessages) < 100 {
|
|
batchSize = numMessages - len(allMessages)
|
|
}
|
|
|
|
batch, err := s.ChannelMessages(channelID, batchSize, lastMessageID, "", "")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(batch) == 0 {
|
|
break // No more messages to fetch
|
|
}
|
|
|
|
// Filter out bot messages and empty messages during fetch
|
|
for _, msg := range batch {
|
|
if !msg.Author.Bot && len(strings.TrimSpace(msg.Content)) > 0 {
|
|
allMessages = append(allMessages, msg)
|
|
}
|
|
}
|
|
|
|
lastMessageID = batch[len(batch)-1].ID
|
|
|
|
if len(batch) < 100 {
|
|
break // Less than 100 messages returned, we've reached the end
|
|
}
|
|
}
|
|
|
|
return allMessages, nil
|
|
}
|
|
|
|
// cleanText removes URLs, mentions, emojis, and normalizes text
|
|
func cleanText(text string) string {
|
|
// Remove URLs
|
|
text = urlRegex.ReplaceAllString(text, "")
|
|
// Remove mentions
|
|
text = mentionRegex.ReplaceAllString(text, "")
|
|
// Remove custom emojis
|
|
text = emojiRegex.ReplaceAllString(text, "")
|
|
// Normalize whitespace
|
|
text = strings.Join(strings.Fields(text), " ")
|
|
return strings.TrimSpace(text)
|
|
}
|
|
|
|
// buildMarkovChain creates an improved Markov chain from a list of messages
|
|
func buildMarkovChain(messages []*discordgo.Message) (map[string][]string, map[string]map[string][]string, map[string]map[string]map[string][]string, map[string]map[string]map[string]map[string][]string, map[string]map[string]map[string]map[string]map[string][]string) {
|
|
chain := make(map[string][]string)
|
|
twoGramChain := make(map[string]map[string][]string)
|
|
threeGramChain := make(map[string]map[string]map[string][]string)
|
|
fourGramChain := make(map[string]map[string]map[string]map[string][]string)
|
|
fiveGramChain := make(map[string]map[string]map[string]map[string]map[string][]string)
|
|
|
|
// Count total words for memory estimation
|
|
totalWords := 0
|
|
for _, msg := range messages {
|
|
cleanedContent := cleanText(msg.Content)
|
|
if len(cleanedContent) >= 3 {
|
|
words := strings.Fields(cleanedContent)
|
|
totalWords += len(words)
|
|
}
|
|
}
|
|
|
|
// Estimate memory usage and adjust max n-gram level
|
|
maxNGram := lib.AppConfig.MarkovMaxNGram
|
|
estimatedMemoryMB := estimateMemoryUsage(totalWords, maxNGram)
|
|
if estimatedMemoryMB > lib.AppConfig.MarkovMemoryLimit {
|
|
// Reduce n-gram level to stay within memory limits
|
|
for maxNGram > 2 && estimateMemoryUsage(totalWords, maxNGram) > lib.AppConfig.MarkovMemoryLimit {
|
|
maxNGram--
|
|
}
|
|
}
|
|
|
|
for _, msg := range messages {
|
|
cleanedContent := cleanText(msg.Content)
|
|
if len(cleanedContent) < 3 {
|
|
continue
|
|
}
|
|
|
|
words := strings.Fields(cleanedContent)
|
|
if len(words) < 2 {
|
|
continue
|
|
}
|
|
|
|
// Build 1-gram chain
|
|
for i := 0; i < len(words)-1; i++ {
|
|
currentWord := strings.ToLower(words[i])
|
|
nextWord := words[i+1]
|
|
|
|
if len(currentWord) < 2 || strings.ContainsAny(currentWord, "!@#$%^&*()[]{}") {
|
|
continue
|
|
}
|
|
|
|
chain[currentWord] = append(chain[currentWord], nextWord)
|
|
}
|
|
|
|
// Build 2-gram chain
|
|
if maxNGram >= 2 {
|
|
for i := 0; i < len(words)-2; i++ {
|
|
word1 := strings.ToLower(words[i])
|
|
word2 := strings.ToLower(words[i+1])
|
|
nextWord := words[i+2]
|
|
|
|
if len(word1) < 2 || len(word2) < 2 ||
|
|
strings.ContainsAny(word1, "!@#$%^&*()[]{}") ||
|
|
strings.ContainsAny(word2, "!@#$%^&*()[]{}") {
|
|
continue
|
|
}
|
|
|
|
if twoGramChain[word1] == nil {
|
|
twoGramChain[word1] = make(map[string][]string)
|
|
}
|
|
twoGramChain[word1][word2] = append(twoGramChain[word1][word2], nextWord)
|
|
}
|
|
}
|
|
|
|
// Build 3-gram chain
|
|
if maxNGram >= 3 {
|
|
for i := 0; i < len(words)-3; i++ {
|
|
word1 := strings.ToLower(words[i])
|
|
word2 := strings.ToLower(words[i+1])
|
|
word3 := strings.ToLower(words[i+2])
|
|
nextWord := words[i+3]
|
|
|
|
if len(word1) < 2 || len(word2) < 2 || len(word3) < 2 ||
|
|
strings.ContainsAny(word1, "!@#$%^&*()[]{}") ||
|
|
strings.ContainsAny(word2, "!@#$%^&*()[]{}") ||
|
|
strings.ContainsAny(word3, "!@#$%^&*()[]{}") {
|
|
continue
|
|
}
|
|
|
|
if threeGramChain[word1] == nil {
|
|
threeGramChain[word1] = make(map[string]map[string][]string)
|
|
}
|
|
if threeGramChain[word1][word2] == nil {
|
|
threeGramChain[word1][word2] = make(map[string][]string)
|
|
}
|
|
threeGramChain[word1][word2][word3] = append(threeGramChain[word1][word2][word3], nextWord)
|
|
}
|
|
}
|
|
|
|
// Build 4-gram chain
|
|
if maxNGram >= 4 {
|
|
for i := 0; i < len(words)-4; i++ {
|
|
word1 := strings.ToLower(words[i])
|
|
word2 := strings.ToLower(words[i+1])
|
|
word3 := strings.ToLower(words[i+2])
|
|
word4 := strings.ToLower(words[i+3])
|
|
nextWord := words[i+4]
|
|
|
|
if len(word1) < 2 || len(word2) < 2 || len(word3) < 2 || len(word4) < 2 ||
|
|
strings.ContainsAny(word1, "!@#$%^&*()[]{}") ||
|
|
strings.ContainsAny(word2, "!@#$%^&*()[]{}") ||
|
|
strings.ContainsAny(word3, "!@#$%^&*()[]{}") ||
|
|
strings.ContainsAny(word4, "!@#$%^&*()[]{}") {
|
|
continue
|
|
}
|
|
|
|
if fourGramChain[word1] == nil {
|
|
fourGramChain[word1] = make(map[string]map[string]map[string][]string)
|
|
}
|
|
if fourGramChain[word1][word2] == nil {
|
|
fourGramChain[word1][word2] = make(map[string]map[string][]string)
|
|
}
|
|
if fourGramChain[word1][word2][word3] == nil {
|
|
fourGramChain[word1][word2][word3] = make(map[string][]string)
|
|
}
|
|
fourGramChain[word1][word2][word3][word4] = append(fourGramChain[word1][word2][word3][word4], nextWord)
|
|
}
|
|
}
|
|
|
|
// Build 5-gram chain for maximum coherence
|
|
if maxNGram >= 5 {
|
|
for i := 0; i < len(words)-5; i++ {
|
|
word1 := strings.ToLower(words[i])
|
|
word2 := strings.ToLower(words[i+1])
|
|
word3 := strings.ToLower(words[i+2])
|
|
word4 := strings.ToLower(words[i+3])
|
|
word5 := strings.ToLower(words[i+4])
|
|
nextWord := words[i+5]
|
|
|
|
if len(word1) < 2 || len(word2) < 2 || len(word3) < 2 || len(word4) < 2 || len(word5) < 2 ||
|
|
strings.ContainsAny(word1, "!@#$%^&*()[]{}") ||
|
|
strings.ContainsAny(word2, "!@#$%^&*()[]{}") ||
|
|
strings.ContainsAny(word3, "!@#$%^&*()[]{}") ||
|
|
strings.ContainsAny(word4, "!@#$%^&*()[]{}") ||
|
|
strings.ContainsAny(word5, "!@#$%^&*()[]{}") {
|
|
continue
|
|
}
|
|
|
|
if fiveGramChain[word1] == nil {
|
|
fiveGramChain[word1] = make(map[string]map[string]map[string]map[string][]string)
|
|
}
|
|
if fiveGramChain[word1][word2] == nil {
|
|
fiveGramChain[word1][word2] = make(map[string]map[string]map[string][]string)
|
|
}
|
|
if fiveGramChain[word1][word2][word3] == nil {
|
|
fiveGramChain[word1][word2][word3] = make(map[string]map[string][]string)
|
|
}
|
|
if fiveGramChain[word1][word2][word3][word4] == nil {
|
|
fiveGramChain[word1][word2][word3][word4] = make(map[string][]string)
|
|
}
|
|
fiveGramChain[word1][word2][word3][word4][word5] = append(fiveGramChain[word1][word2][word3][word4][word5], nextWord)
|
|
}
|
|
}
|
|
}
|
|
|
|
return chain, twoGramChain, threeGramChain, fourGramChain, fiveGramChain
|
|
}
|
|
|
|
// estimateMemoryUsage returns a rough memory estimate, in MB, for building
// tables up to maxNGram from wordCount words. The base figure is 1 MB per
// full 1000 words (integer division), scaled by a fixed factor for orders 2
// through 6; any other order gets the unscaled base.
func estimateMemoryUsage(wordCount int, maxNGram int) int {
	baseMB := wordCount / 1000

	// Scale factor indexed by n-gram order; indexes 0 and 1 are unused.
	factors := [...]int{2: 5, 3: 15, 4: 35, 5: 75, 6: 150}
	if maxNGram >= 2 && maxNGram < len(factors) {
		return baseMB * factors[maxNGram]
	}
	return baseMB
}
|
|
|
|
// generateMessage produces a sentence by random walk over a 1-gram chain
// (lower-cased word -> observed follow-up words).
//
// The walk starts at a word longer than two bytes with at least two
// follow-ups when one turns up within the first 51 entries visited; otherwise
// it starts at the last entry visited, so a small chain without such a word
// still produces output. Between 5 and 24 words are emitted. At a dead end
// the walk restarts from any word longer than two bytes that has follow-ups,
// or stops if there is none. The first word is capitalized, and a random
// '.', '!' or '?' is appended unless the text already ends in one.
//
// It returns "" for an empty chain.
func generateMessage(chain map[string][]string) string {
	if len(chain) == 0 {
		return ""
	}

	// capitalize upper-cases only the first character. strings.Title is not
	// suitable here: it upper-cases after every separator, turning "it's"
	// into "It'S".
	capitalize := func(word string) string {
		firstLen := len(word)
		for idx := range word {
			if idx > 0 {
				firstLen = idx
				break
			}
		}
		return strings.ToUpper(word[:firstLen]) + word[firstLen:]
	}

	// Map iteration order is unspecified, which is what makes the start word
	// random. currentWord always tracks the entry being looked at, so it
	// doubles as the fallback when the loop ends without a preferred word.
	var currentWord string
	attempts := 0
	for word, nextWords := range chain {
		currentWord = word
		if len(nextWords) >= 2 && len(word) > 2 {
			break
		}
		attempts++
		if attempts > 50 {
			break
		}
	}
	if currentWord == "" {
		return ""
	}

	maxWords := 5 + rand.Intn(20) // 5..24 words
	words := make([]string, 0, maxWords)
	for i := 0; i < maxWords; i++ {
		if i == 0 {
			words = append(words, capitalize(currentWord))
		} else {
			words = append(words, currentWord)
		}

		if nextWords := chain[strings.ToLower(currentWord)]; len(nextWords) > 0 {
			currentWord = nextWords[rand.Intn(len(nextWords))]
			continue
		}

		// Dead end: restart from another word that can be continued.
		restarted := false
		for word, nextWords := range chain {
			if len(nextWords) > 0 && len(word) > 2 {
				currentWord = word
				restarted = true
				break
			}
		}
		if !restarted {
			break
		}
	}

	result := strings.Join(words, " ")

	if len(result) > 0 {
		switch result[len(result)-1] {
		case '.', '!', '?':
			// Already punctuated.
		default:
			punctuation := []string{".", "!", "?"}
			result += punctuation[rand.Intn(len(punctuation))]
		}
	}

	return result
}
|
|
|
|
// init seeds the package-level math/rand source from the current time so
// generated messages differ between runs.
//
// NOTE(review): rand.Seed is deprecated since Go 1.20, where the global
// source is seeded randomly at startup; this is only needed when building
// with an older toolchain.
func init() {
	seed := time.Now().UnixNano()
	rand.Seed(seed)
}
|
|
|
|
// MarkovQuestionCommand generates a markov chain answer to a question based on channel contents
|
|
func MarkovQuestionCommand(s *discordgo.Session, i *discordgo.InteractionCreate) (string, error) {
|
|
channelID := i.ChannelID
|
|
|
|
var question string
|
|
var numMessages int = lib.AppConfig.MarkovDefaultMessages
|
|
|
|
for _, option := range i.ApplicationCommandData().Options {
|
|
switch option.Name {
|
|
case "question":
|
|
question = option.StringValue()
|
|
case "messages":
|
|
numMessages = int(option.IntValue())
|
|
if numMessages <= 0 {
|
|
numMessages = lib.AppConfig.MarkovDefaultMessages
|
|
} else if numMessages > lib.AppConfig.MarkovMaxMessages {
|
|
numMessages = lib.AppConfig.MarkovMaxMessages
|
|
}
|
|
}
|
|
}
|
|
|
|
if question == "" {
|
|
return "Please provide a question!", nil
|
|
}
|
|
|
|
cacheKey := fmt.Sprintf("%s:%d", channelID, numMessages)
|
|
var chain map[string][]string
|
|
var twoGramChain map[string]map[string][]string
|
|
var threeGramChain map[string]map[string]map[string][]string
|
|
var fourGramChain map[string]map[string]map[string]map[string][]string
|
|
var fiveGramChain map[string]map[string]map[string]map[string]map[string][]string
|
|
|
|
if cachedChain := getCachedChain(cacheKey); cachedChain != nil {
|
|
chain = cachedChain
|
|
twoGramChain = getCachedTwoGramChain(cacheKey)
|
|
threeGramChain = getCachedThreeGramChain(cacheKey)
|
|
fourGramChain = getCachedFourGramChain(cacheKey)
|
|
fiveGramChain = getCachedFiveGramChain(cacheKey)
|
|
} else {
|
|
allMessages, err := fetchMessages(s, channelID, numMessages)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
chain, twoGramChain, threeGramChain, fourGramChain, fiveGramChain = buildMarkovChain(allMessages)
|
|
setCachedChain(cacheKey, chain, twoGramChain, threeGramChain, fourGramChain, fiveGramChain, allMessages)
|
|
}
|
|
|
|
answer := generateAdvancedQuestionAnswer(chain, twoGramChain, threeGramChain, fourGramChain, fiveGramChain, question)
|
|
|
|
if answer == "" {
|
|
answer = "I couldn't generate an answer to that question. The channel might not have enough relevant content."
|
|
}
|
|
|
|
return fmt.Sprintf("**Q:** %s\n**A:** %s", question, answer), nil
|
|
}
|
|
|
|
// generateQuestionAnswer generates a markov chain response that attempts to answer the given question
|
|
func generateQuestionAnswer(chain map[string][]string, twoGramChain map[string]map[string][]string, question string) string {
|
|
if len(chain) == 0 {
|
|
return ""
|
|
}
|
|
|
|
// Clean and analyze the question to find relevant starting words
|
|
cleanedQuestion := cleanText(question)
|
|
questionWords := strings.Fields(strings.ToLower(cleanedQuestion))
|
|
|
|
// Categorize the question type for better response generation
|
|
questionType := categorizeQuestion(cleanedQuestion)
|
|
|
|
// Find potential starting words with weighted scoring
|
|
startingCandidates := findBestStartingWords(chain, questionWords, questionType)
|
|
|
|
if len(startingCandidates) == 0 {
|
|
return ""
|
|
}
|
|
|
|
// Generate response using the best starting candidate
|
|
return generateCoherentResponse(chain, twoGramChain, startingCandidates, questionType)
|
|
}
|
|
|
|
// categorizeQuestion classifies a question by the first keyword it contains,
// checked in this priority order: "what", "how", "why", "when", "where",
// "who", "which" (each returned as-is), then "is", "are", "do", "does"
// (returned as "yesno"). Anything else is "general".
//
// Keywords are matched as whole words, case-insensitively, so "show" is not
// "how" and "this" is not "is". Contractions and casual spellings are folded
// onto their keyword: "what's" and "whats" count as "what", "isn't" as "is",
// "doesn't" as "does".
func categorizeQuestion(question string) string {
	// Treat the typographic apostrophe like the ASCII one.
	normalized := strings.ReplaceAll(strings.ToLower(question), "\u2019", "'")

	// Split on ASCII whitespace and punctuation; apostrophes stay inside
	// tokens so contractions can be recognised, and non-ASCII characters are
	// kept as word characters.
	tokens := strings.FieldsFunc(normalized, func(r rune) bool {
		switch {
		case r == '\'', r >= 'a' && r <= 'z', r >= '0' && r <= '9', r > 127:
			return false
		}
		return true
	})

	seen := make(map[string]bool, len(tokens))
	for _, token := range tokens {
		token = strings.Trim(token, "'")
		if strings.HasSuffix(token, "n't") {
			// isn't -> is, doesn't -> does, don't -> do, aren't -> are
			token = strings.TrimSuffix(token, "n't")
		} else if cut := strings.IndexByte(token, '\''); cut >= 0 {
			// what's -> what, who'd -> who
			token = token[:cut]
		}
		if token != "" {
			seen[token] = true
		}
	}

	for _, keyword := range []string{"what", "how", "why", "when", "where", "who", "which"} {
		// The "+s" form covers apostrophe-less spellings such as "whats".
		if seen[keyword] || seen[keyword+"s"] {
			return keyword
		}
	}

	for _, auxiliary := range []string{"is", "are", "do", "does"} {
		if seen[auxiliary] {
			return "yesno"
		}
	}

	return "general"
}
|
|
|
|
// WordCandidate pairs a possible starting word for generation with the
// relevance score assigned to it; a higher score means a better candidate.
type WordCandidate struct {
	// Word is the candidate word, used as a key into the chain.
	Word string
	// Score is the relevance score used to rank candidates.
	Score int
}
|
|
|
|
// findBestStartingWords finds and scores potential starting words based on question relevance
|
|
func findBestStartingWords(chain map[string][]string, questionWords []string, questionType string) []WordCandidate {
|
|
candidates := make(map[string]int)
|
|
|
|
// Score words from the question that exist in our chain
|
|
for _, word := range questionWords {
|
|
if len(word) > 2 && !isStopWord(word) {
|
|
if nextWords, exists := chain[word]; exists && len(nextWords) > 0 {
|
|
candidates[word] += 10 // High score for direct question words
|
|
}
|
|
}
|
|
}
|
|
|
|
// Add contextually relevant words based on question type
|
|
contextWords := getContextualWords(questionType)
|
|
for _, word := range contextWords {
|
|
if nextWords, exists := chain[word]; exists && len(nextWords) > 0 {
|
|
candidates[word] += 5 // Medium score for contextual words
|
|
}
|
|
}
|
|
|
|
// Add high-frequency words as fallback
|
|
for word, nextWords := range chain {
|
|
if len(nextWords) >= 3 && len(word) > 2 && !isStopWord(word) {
|
|
if _, exists := candidates[word]; !exists {
|
|
candidates[word] = len(nextWords) / 2 // Score based on frequency
|
|
}
|
|
}
|
|
}
|
|
|
|
// Convert to sorted slice
|
|
var result []WordCandidate
|
|
for word, score := range candidates {
|
|
result = append(result, WordCandidate{Word: word, Score: score})
|
|
}
|
|
|
|
// Sort by score (highest first)
|
|
for i := 0; i < len(result)-1; i++ {
|
|
for j := i + 1; j < len(result); j++ {
|
|
if result[j].Score > result[i].Score {
|
|
result[i], result[j] = result[j], result[i]
|
|
}
|
|
}
|
|
}
|
|
|
|
// Return top candidates
|
|
if len(result) > 10 {
|
|
result = result[:10]
|
|
}
|
|
|
|
return result
|
|
}
|
|
|
|
// getContextualWords returns a list of words associated with the given
// question type ("what", "how", "why", "when", "where", "who", "which" or
// "yesno"). Any other type gets a generic list. A new slice is returned on
// every call.
func getContextualWords(questionType string) []string {
	byType := map[string][]string{
		"what":  {"thing", "something", "object", "idea", "concept", "stuff", "item"},
		"how":   {"way", "method", "process", "steps", "technique", "approach"},
		"why":   {"because", "reason", "cause", "since", "due", "explanation"},
		"when":  {"time", "moment", "day", "hour", "yesterday", "today", "tomorrow", "now", "then"},
		"where": {"place", "location", "here", "there", "somewhere", "anywhere"},
		"who":   {"person", "people", "someone", "anyone", "everybody", "nobody"},
		"which": {"choice", "option", "selection", "pick", "prefer"},
		"yesno": {"yes", "no", "maybe", "definitely", "probably", "possibly", "sure", "absolutely"},
	}

	if words, known := byType[questionType]; known {
		return words
	}
	return []string{"think", "believe", "know", "understand", "feel", "seem"}
}
|
|
|
|
// generateCoherentResponse creates a more coherent response using improved algorithms
|
|
func generateCoherentResponse(chain map[string][]string, twoGramChain map[string]map[string][]string, candidates []WordCandidate, questionType string) string {
|
|
if len(candidates) == 0 {
|
|
return ""
|
|
}
|
|
|
|
// Try multiple generation attempts and pick the best one
|
|
var bestResponse string
|
|
bestScore := 0
|
|
|
|
for attempt := 0; attempt < 3; attempt++ {
|
|
// Select starting word (bias towards higher scored candidates)
|
|
candidateIndex := 0
|
|
if len(candidates) > 1 {
|
|
// 70% chance to pick top candidate, 30% for others
|
|
if rand.Float32() > 0.7 && len(candidates) > 1 {
|
|
candidateIndex = rand.Intn(min(3, len(candidates)))
|
|
}
|
|
}
|
|
|
|
currentWord := candidates[candidateIndex].Word
|
|
words := []string{}
|
|
|
|
// Generate response with improved coherence
|
|
maxWords := 8 + rand.Intn(22) // 8-22 words
|
|
lastWord := ""
|
|
|
|
for i := 0; i < maxWords; i++ {
|
|
// Add current word (capitalize first word)
|
|
if i == 0 {
|
|
words = append(words, strings.Title(currentWord))
|
|
} else {
|
|
words = append(words, currentWord)
|
|
}
|
|
|
|
var nextWord string
|
|
|
|
// Try 2-gram chain first for better coherence
|
|
if lastWord != "" {
|
|
if twoGramOptions, exists := twoGramChain[strings.ToLower(lastWord)][strings.ToLower(currentWord)]; exists && len(twoGramOptions) > 0 {
|
|
nextWord = twoGramOptions[rand.Intn(len(twoGramOptions))]
|
|
}
|
|
}
|
|
|
|
// Fallback to regular chain
|
|
if nextWord == "" {
|
|
if nextWords, exists := chain[strings.ToLower(currentWord)]; exists && len(nextWords) > 0 {
|
|
// Prefer longer words for better content
|
|
var goodOptions []string
|
|
for _, word := range nextWords {
|
|
if len(word) > 2 && !isStopWord(strings.ToLower(word)) {
|
|
goodOptions = append(goodOptions, word)
|
|
}
|
|
}
|
|
|
|
if len(goodOptions) > 0 {
|
|
nextWord = goodOptions[rand.Intn(len(goodOptions))]
|
|
} else {
|
|
nextWord = nextWords[rand.Intn(len(nextWords))]
|
|
}
|
|
}
|
|
}
|
|
|
|
// If we can't find a next word, try to restart with a good candidate
|
|
if nextWord == "" {
|
|
found := false
|
|
for _, candidate := range candidates {
|
|
if nextWords, exists := chain[candidate.Word]; exists && len(nextWords) > 0 {
|
|
nextWord = candidate.Word
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
break
|
|
}
|
|
}
|
|
|
|
lastWord = currentWord
|
|
currentWord = nextWord
|
|
}
|
|
|
|
response := strings.Join(words, " ")
|
|
|
|
// Score this response
|
|
score := scoreResponse(response, questionType)
|
|
|
|
if score > bestScore {
|
|
bestScore = score
|
|
bestResponse = response
|
|
}
|
|
}
|
|
|
|
// Add appropriate punctuation
|
|
if len(bestResponse) > 0 && !strings.ContainsAny(bestResponse[len(bestResponse)-1:], ".!?") {
|
|
punctuation := getPunctuationForQuestionType(questionType)
|
|
bestResponse += punctuation[rand.Intn(len(punctuation))]
|
|
}
|
|
|
|
return bestResponse
|
|
}
|
|
|
|
// scoreResponse scores a response based on various quality metrics
|
|
func scoreResponse(response string, questionType string) int {
|
|
score := 0
|
|
words := strings.Fields(response)
|
|
|
|
// Length score (prefer 8-16 words)
|
|
if len(words) >= 8 && len(words) <= 16 {
|
|
score += 10
|
|
} else if len(words) >= 6 && len(words) <= 20 {
|
|
score += 5
|
|
}
|
|
|
|
// Diversity score (prefer responses with varied word lengths)
|
|
totalLength := 0
|
|
for _, word := range words {
|
|
totalLength += len(word)
|
|
}
|
|
avgWordLength := float64(totalLength) / float64(len(words))
|
|
if avgWordLength > 3.5 && avgWordLength < 6.0 {
|
|
score += 5
|
|
}
|
|
|
|
// Content word score (prefer responses with meaningful words)
|
|
contentWords := 0
|
|
for _, word := range words {
|
|
if len(word) > 3 && !isStopWord(strings.ToLower(word)) {
|
|
contentWords++
|
|
}
|
|
}
|
|
score += contentWords
|
|
|
|
return score
|
|
}
|
|
|
|
// getPunctuationForQuestionType returns the pool of closing punctuation marks
// to pick from for the given question type. Marks that appear more than once
// are proportionally more likely to be picked by a uniform random choice.
func getPunctuationForQuestionType(questionType string) []string {
	if questionType == "yesno" {
		return []string{".", "!", "."}
	}
	if questionType == "why" || questionType == "how" {
		return []string{".", ".", "!"}
	}
	return []string{".", ".", "!", "."}
}
|
|
|
|
// min returns the smaller of a and b.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
|
|
|
|
// isStopWord reports whether word is one of a fixed set of common English
// function words that make poor starting points or content words. The match
// is exact and case-sensitive, so callers pass lower-cased words.
//
// It is called inside generation and scoring loops, so the set is a switch
// rather than a map that would be rebuilt on every call.
func isStopWord(word string) bool {
	switch word {
	case "a", "an", "and", "are", "as", "at", "be", "by",
		"for", "from", "has", "he", "in", "is", "it",
		"its", "of", "on", "that", "the", "to", "was",
		"will", "with", "or", "but", "if", "so", "do":
		return true
	}
	return false
}
|
|
|
|
// buildTwoGramChain derives a two-level table from a 1-gram chain. For every
// word1 with at least one follow-up, the result has an entry for word1; under
// it, each follow-up word2 (lower-cased) that is itself a key of chain maps to
// chain's follow-ups of word2. Follow-ups that are not keys of chain are left
// out, so an entry for word1 may be an empty map.
//
// The stored slices are shared with chain, not copied.
func buildTwoGramChain(chain map[string][]string) map[string]map[string][]string {
	derived := make(map[string]map[string][]string)

	for first, followers := range chain {
		if len(followers) == 0 {
			continue
		}

		pairs := make(map[string][]string)
		derived[first] = pairs

		for _, follower := range followers {
			second := strings.ToLower(follower)
			if continuation, known := chain[second]; known {
				pairs[second] = continuation
			}
		}
	}

	return derived
}
|
|
|
|
// generateImprovedMessage creates a new message using both 1-gram and 2-gram chains for better coherence
|
|
func generateImprovedMessage(chain map[string][]string, twoGramChain map[string]map[string][]string) string {
|
|
if len(chain) == 0 {
|
|
return ""
|
|
}
|
|
|
|
// Try multiple generation attempts and pick the best one
|
|
var bestMessage string
|
|
bestScore := 0
|
|
|
|
for attempt := 0; attempt < 3; attempt++ {
|
|
words := []string{}
|
|
var currentWord string
|
|
|
|
// Start with a random word that has good follow-ups
|
|
attempts := 0
|
|
for word, nextWords := range chain {
|
|
if len(nextWords) >= 2 && len(word) > 2 && !isStopWord(word) { // Prefer words with multiple options
|
|
currentWord = word
|
|
break
|
|
}
|
|
attempts++
|
|
if attempts > 50 { // Fallback to any word
|
|
currentWord = word
|
|
break
|
|
}
|
|
}
|
|
|
|
if currentWord == "" {
|
|
continue
|
|
}
|
|
|
|
// Generate between 8 and 20 words for better content
|
|
maxWords := 8 + rand.Intn(12)
|
|
lastWord := ""
|
|
|
|
for i := 0; i < maxWords; i++ {
|
|
// Add current word (capitalize first word)
|
|
if i == 0 {
|
|
words = append(words, strings.Title(currentWord))
|
|
} else {
|
|
words = append(words, currentWord)
|
|
}
|
|
|
|
var nextWord string
|
|
|
|
// Try 2-gram chain first for better coherence
|
|
if lastWord != "" && twoGramChain != nil {
|
|
if twoGramOptions, exists := twoGramChain[strings.ToLower(lastWord)][strings.ToLower(currentWord)]; exists && len(twoGramOptions) > 0 {
|
|
nextWord = twoGramOptions[rand.Intn(len(twoGramOptions))]
|
|
}
|
|
}
|
|
|
|
// Fallback to regular chain
|
|
if nextWord == "" {
|
|
if nextWords, exists := chain[strings.ToLower(currentWord)]; exists && len(nextWords) > 0 {
|
|
// Prefer longer, more meaningful words
|
|
var goodOptions []string
|
|
for _, word := range nextWords {
|
|
if len(word) > 2 && !isStopWord(strings.ToLower(word)) {
|
|
goodOptions = append(goodOptions, word)
|
|
}
|
|
}
|
|
|
|
if len(goodOptions) > 0 {
|
|
nextWord = goodOptions[rand.Intn(len(goodOptions))]
|
|
} else {
|
|
nextWord = nextWords[rand.Intn(len(nextWords))]
|
|
}
|
|
}
|
|
}
|
|
|
|
// If we can't find a next word, try to restart
|
|
if nextWord == "" {
|
|
found := false
|
|
for word, nextWords := range chain {
|
|
if len(nextWords) > 0 && len(word) > 2 && !isStopWord(word) {
|
|
nextWord = word
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
break
|
|
}
|
|
}
|
|
|
|
lastWord = currentWord
|
|
currentWord = nextWord
|
|
}
|
|
|
|
message := strings.Join(words, " ")
|
|
|
|
// Score this message
|
|
score := scoreGeneratedMessage(message)
|
|
|
|
if score > bestScore {
|
|
bestScore = score
|
|
bestMessage = message
|
|
}
|
|
}
|
|
|
|
// Add punctuation if missing
|
|
if len(bestMessage) > 0 && !strings.ContainsAny(bestMessage[len(bestMessage)-1:], ".!?") {
|
|
// Randomly add punctuation
|
|
punctuation := []string{".", "!", "?", "."}
|
|
bestMessage += punctuation[rand.Intn(len(punctuation))]
|
|
}
|
|
|
|
return bestMessage
|
|
}
|
|
|
|
// scoreGeneratedMessage scores a generated message based on quality metrics
|
|
func scoreGeneratedMessage(message string) int {
|
|
score := 0
|
|
words := strings.Fields(message)
|
|
|
|
// Length score (prefer 8-16 words)
|
|
if len(words) >= 8 && len(words) <= 16 {
|
|
score += 10
|
|
} else if len(words) >= 6 && len(words) <= 20 {
|
|
score += 5
|
|
}
|
|
|
|
// Diversity score (prefer responses with varied word lengths)
|
|
totalLength := 0
|
|
for _, word := range words {
|
|
totalLength += len(word)
|
|
}
|
|
if len(words) > 0 {
|
|
avgWordLength := float64(totalLength) / float64(len(words))
|
|
if avgWordLength > 3.0 && avgWordLength < 7.0 {
|
|
score += 5
|
|
}
|
|
}
|
|
|
|
// Content word score (prefer messages with meaningful words)
|
|
contentWords := 0
|
|
for _, word := range words {
|
|
if len(word) > 3 && !isStopWord(strings.ToLower(word)) {
|
|
contentWords++
|
|
}
|
|
}
|
|
score += contentWords
|
|
|
|
return score
|
|
}
|
|
|
|
// generateAdvancedMessage creates a new message using all n-gram chains for maximum coherence
|
|
func generateAdvancedMessage(chain map[string][]string, twoGramChain map[string]map[string][]string, threeGramChain map[string]map[string]map[string][]string, fourGramChain map[string]map[string]map[string]map[string][]string, fiveGramChain map[string]map[string]map[string]map[string]map[string][]string) string {
|
|
if len(chain) == 0 {
|
|
return ""
|
|
}
|
|
|
|
// Try multiple generation attempts and pick the best one
|
|
var bestMessage string
|
|
bestScore := 0
|
|
|
|
for attempt := 0; attempt < 5; attempt++ {
|
|
words := []string{}
|
|
var currentWord string
|
|
|
|
// Start with a random word that has good follow-ups
|
|
attempts := 0
|
|
for word, nextWords := range chain {
|
|
if len(nextWords) >= 2 && len(word) > 2 && !isStopWord(word) {
|
|
currentWord = word
|
|
break
|
|
}
|
|
attempts++
|
|
if attempts > 50 {
|
|
currentWord = word
|
|
break
|
|
}
|
|
}
|
|
|
|
if currentWord == "" {
|
|
continue
|
|
}
|
|
|
|
// Generate between 10 and 18 words for better content
|
|
maxWords := 10 + rand.Intn(8)
|
|
wordHistory := []string{currentWord}
|
|
|
|
for i := 0; i < maxWords; i++ {
|
|
// Add current word (capitalize first word)
|
|
if i == 0 {
|
|
words = append(words, strings.Title(currentWord))
|
|
} else {
|
|
words = append(words, currentWord)
|
|
}
|
|
|
|
var nextWord string
|
|
historyLen := len(wordHistory)
|
|
|
|
// Try 5-gram chain first (highest coherence)
|
|
if historyLen >= 5 && fiveGramChain != nil {
|
|
w1, w2, w3, w4, w5 := strings.ToLower(wordHistory[historyLen-5]), strings.ToLower(wordHistory[historyLen-4]), strings.ToLower(wordHistory[historyLen-3]), strings.ToLower(wordHistory[historyLen-2]), strings.ToLower(wordHistory[historyLen-1])
|
|
if options, exists := fiveGramChain[w1][w2][w3][w4][w5]; exists && len(options) > 0 {
|
|
nextWord = selectBestNextWord(options, wordHistory)
|
|
}
|
|
}
|
|
|
|
// Try 4-gram chain if 5-gram failed
|
|
if nextWord == "" && historyLen >= 4 && fourGramChain != nil {
|
|
w1, w2, w3, w4 := strings.ToLower(wordHistory[historyLen-4]), strings.ToLower(wordHistory[historyLen-3]), strings.ToLower(wordHistory[historyLen-2]), strings.ToLower(wordHistory[historyLen-1])
|
|
if options, exists := fourGramChain[w1][w2][w3][w4]; exists && len(options) > 0 {
|
|
nextWord = selectBestNextWord(options, wordHistory)
|
|
}
|
|
}
|
|
|
|
// Try 3-gram chain if 4-gram failed
|
|
if nextWord == "" && historyLen >= 3 && threeGramChain != nil {
|
|
w1, w2, w3 := strings.ToLower(wordHistory[historyLen-3]), strings.ToLower(wordHistory[historyLen-2]), strings.ToLower(wordHistory[historyLen-1])
|
|
if options, exists := threeGramChain[w1][w2][w3]; exists && len(options) > 0 {
|
|
nextWord = selectBestNextWord(options, wordHistory)
|
|
}
|
|
}
|
|
|
|
// Try 2-gram chain if 3-gram failed
|
|
if nextWord == "" && historyLen >= 2 && twoGramChain != nil {
|
|
w1, w2 := strings.ToLower(wordHistory[historyLen-2]), strings.ToLower(wordHistory[historyLen-1])
|
|
if options, exists := twoGramChain[w1][w2]; exists && len(options) > 0 {
|
|
nextWord = selectBestNextWord(options, wordHistory)
|
|
}
|
|
}
|
|
|
|
// Fallback to 1-gram chain
|
|
if nextWord == "" {
|
|
if nextWords, exists := chain[strings.ToLower(currentWord)]; exists && len(nextWords) > 0 {
|
|
nextWord = selectBestNextWord(nextWords, wordHistory)
|
|
}
|
|
}
|
|
|
|
// If we still can't find a next word, try to restart
|
|
if nextWord == "" {
|
|
found := false
|
|
for word, nextWords := range chain {
|
|
if len(nextWords) > 0 && len(word) > 2 && !isStopWord(word) {
|
|
nextWord = word
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
break
|
|
}
|
|
}
|
|
|
|
currentWord = nextWord
|
|
wordHistory = append(wordHistory, currentWord)
|
|
|
|
// Keep history manageable
|
|
if len(wordHistory) > 10 {
|
|
wordHistory = wordHistory[1:]
|
|
}
|
|
}
|
|
|
|
message := strings.Join(words, " ")
|
|
|
|
// Score this message with enhanced scoring
|
|
score := scoreAdvancedMessage(message)
|
|
|
|
if score > bestScore {
|
|
bestScore = score
|
|
bestMessage = message
|
|
}
|
|
}
|
|
|
|
// Add punctuation if missing
|
|
if len(bestMessage) > 0 && !strings.ContainsAny(bestMessage[len(bestMessage)-1:], ".!?") {
|
|
punctuation := []string{".", "!", "?", "."}
|
|
bestMessage += punctuation[rand.Intn(len(punctuation))]
|
|
}
|
|
|
|
return bestMessage
|
|
}
|
|
|
|
// generateAdvancedQuestionAnswer generates a markov chain response using all n-gram levels
|
|
func generateAdvancedQuestionAnswer(chain map[string][]string, twoGramChain map[string]map[string][]string, threeGramChain map[string]map[string]map[string][]string, fourGramChain map[string]map[string]map[string]map[string][]string, fiveGramChain map[string]map[string]map[string]map[string]map[string][]string, question string) string {
|
|
if len(chain) == 0 {
|
|
return ""
|
|
}
|
|
|
|
// Clean and analyze the question to find relevant starting words
|
|
cleanedQuestion := cleanText(question)
|
|
questionWords := strings.Fields(strings.ToLower(cleanedQuestion))
|
|
|
|
// Categorize the question type for better response generation
|
|
questionType := categorizeQuestion(cleanedQuestion)
|
|
|
|
// Find potential starting words with weighted scoring
|
|
startingCandidates := findBestStartingWords(chain, questionWords, questionType)
|
|
|
|
if len(startingCandidates) == 0 {
|
|
return ""
|
|
}
|
|
|
|
// Generate response using the best starting candidate with advanced n-gram chains
|
|
return generateAdvancedCoherentResponse(chain, twoGramChain, threeGramChain, fourGramChain, fiveGramChain, startingCandidates, questionType)
|
|
}
|
|
|
|
// generateAdvancedCoherentResponse creates a more coherent response using all n-gram levels
|
|
func generateAdvancedCoherentResponse(chain map[string][]string, twoGramChain map[string]map[string][]string, threeGramChain map[string]map[string]map[string][]string, fourGramChain map[string]map[string]map[string]map[string][]string, fiveGramChain map[string]map[string]map[string]map[string]map[string][]string, candidates []WordCandidate, questionType string) string {
|
|
if len(candidates) == 0 {
|
|
return ""
|
|
}
|
|
|
|
// Try multiple generation attempts and pick the best one
|
|
var bestResponse string
|
|
bestScore := 0
|
|
|
|
for attempt := 0; attempt < 5; attempt++ {
|
|
// Select starting word (bias towards higher scored candidates)
|
|
candidateIndex := 0
|
|
if len(candidates) > 1 {
|
|
// 70% chance to pick top candidate, 30% for others
|
|
if rand.Float32() > 0.7 && len(candidates) > 1 {
|
|
candidateIndex = rand.Intn(min(3, len(candidates)))
|
|
}
|
|
}
|
|
|
|
currentWord := candidates[candidateIndex].Word
|
|
words := []string{}
|
|
wordHistory := []string{currentWord}
|
|
|
|
// Generate response with improved coherence using all n-gram levels
|
|
maxWords := 12 + rand.Intn(10) // 12-22 words for substantial answers
|
|
|
|
for i := 0; i < maxWords; i++ {
|
|
// Add current word (capitalize first word)
|
|
if i == 0 {
|
|
words = append(words, strings.Title(currentWord))
|
|
} else {
|
|
words = append(words, currentWord)
|
|
}
|
|
|
|
var nextWord string
|
|
historyLen := len(wordHistory)
|
|
|
|
// Try 5-gram chain first (highest coherence)
|
|
if historyLen >= 5 && fiveGramChain != nil {
|
|
w1, w2, w3, w4, w5 := strings.ToLower(wordHistory[historyLen-5]), strings.ToLower(wordHistory[historyLen-4]), strings.ToLower(wordHistory[historyLen-3]), strings.ToLower(wordHistory[historyLen-2]), strings.ToLower(wordHistory[historyLen-1])
|
|
if options, exists := fiveGramChain[w1][w2][w3][w4][w5]; exists && len(options) > 0 {
|
|
nextWord = selectBestNextWord(options, wordHistory)
|
|
}
|
|
}
|
|
|
|
// Try 4-gram chain if 5-gram failed
|
|
if nextWord == "" && historyLen >= 4 && fourGramChain != nil {
|
|
w1, w2, w3, w4 := strings.ToLower(wordHistory[historyLen-4]), strings.ToLower(wordHistory[historyLen-3]), strings.ToLower(wordHistory[historyLen-2]), strings.ToLower(wordHistory[historyLen-1])
|
|
if options, exists := fourGramChain[w1][w2][w3][w4]; exists && len(options) > 0 {
|
|
nextWord = selectBestNextWord(options, wordHistory)
|
|
}
|
|
}
|
|
|
|
// Try 3-gram chain if 4-gram failed
|
|
if nextWord == "" && historyLen >= 3 && threeGramChain != nil {
|
|
w1, w2, w3 := strings.ToLower(wordHistory[historyLen-3]), strings.ToLower(wordHistory[historyLen-2]), strings.ToLower(wordHistory[historyLen-1])
|
|
if options, exists := threeGramChain[w1][w2][w3]; exists && len(options) > 0 {
|
|
nextWord = selectBestNextWord(options, wordHistory)
|
|
}
|
|
}
|
|
|
|
// Try 2-gram chain if 3-gram failed
|
|
if nextWord == "" && historyLen >= 2 && twoGramChain != nil {
|
|
w1, w2 := strings.ToLower(wordHistory[historyLen-2]), strings.ToLower(wordHistory[historyLen-1])
|
|
if options, exists := twoGramChain[w1][w2]; exists && len(options) > 0 {
|
|
nextWord = selectBestNextWord(options, wordHistory)
|
|
}
|
|
}
|
|
|
|
// Fallback to regular chain with preference for meaningful words
|
|
if nextWord == "" {
|
|
if nextWords, exists := chain[strings.ToLower(currentWord)]; exists && len(nextWords) > 0 {
|
|
nextWord = selectBestNextWord(nextWords, wordHistory)
|
|
}
|
|
}
|
|
|
|
// If we can't find a next word, try to restart with a good candidate
|
|
if nextWord == "" {
|
|
found := false
|
|
for _, candidate := range candidates {
|
|
if nextWords, exists := chain[candidate.Word]; exists && len(nextWords) > 0 {
|
|
nextWord = candidate.Word
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
break
|
|
}
|
|
}
|
|
|
|
currentWord = nextWord
|
|
wordHistory = append(wordHistory, currentWord)
|
|
|
|
// Keep history manageable
|
|
if len(wordHistory) > 10 {
|
|
wordHistory = wordHistory[1:]
|
|
}
|
|
}
|
|
|
|
response := strings.Join(words, " ")
|
|
|
|
// Score this response with enhanced scoring
|
|
score := scoreAdvancedResponse(response, questionType)
|
|
|
|
if score > bestScore {
|
|
bestScore = score
|
|
bestResponse = response
|
|
}
|
|
}
|
|
|
|
// Add appropriate punctuation
|
|
if len(bestResponse) > 0 && !strings.ContainsAny(bestResponse[len(bestResponse)-1:], ".!?") {
|
|
punctuation := getPunctuationForQuestionType(questionType)
|
|
bestResponse += punctuation[rand.Intn(len(punctuation))]
|
|
}
|
|
|
|
return bestResponse
|
|
}
|
|
|
|
// scoreAdvancedMessage scores a generated message with enhanced metrics
|
|
func scoreAdvancedMessage(message string) int {
|
|
score := 0
|
|
words := strings.Fields(message)
|
|
|
|
// Length score (prefer 10-16 words)
|
|
if len(words) >= 10 && len(words) <= 16 {
|
|
score += 15
|
|
} else if len(words) >= 8 && len(words) <= 18 {
|
|
score += 10
|
|
} else if len(words) >= 6 && len(words) <= 20 {
|
|
score += 5
|
|
}
|
|
|
|
// Diversity score (prefer responses with varied word lengths)
|
|
totalLength := 0
|
|
uniqueWords := make(map[string]bool)
|
|
for _, word := range words {
|
|
totalLength += len(word)
|
|
uniqueWords[strings.ToLower(word)] = true
|
|
}
|
|
|
|
if len(words) > 0 {
|
|
avgWordLength := float64(totalLength) / float64(len(words))
|
|
if avgWordLength > 3.5 && avgWordLength < 6.5 {
|
|
score += 8
|
|
}
|
|
|
|
// Uniqueness score (penalize repetition)
|
|
uniqueRatio := float64(len(uniqueWords)) / float64(len(words))
|
|
if uniqueRatio > 0.8 {
|
|
score += 10
|
|
} else if uniqueRatio > 0.6 {
|
|
score += 5
|
|
}
|
|
}
|
|
|
|
// Content word score (prefer messages with meaningful words)
|
|
contentWords := 0
|
|
for _, word := range words {
|
|
if len(word) > 3 && !isStopWord(strings.ToLower(word)) {
|
|
contentWords++
|
|
}
|
|
}
|
|
score += contentWords * 2
|
|
|
|
// Grammar coherence bonus (simple heuristics)
|
|
if !strings.Contains(message, " a a ") && !strings.Contains(message, " the the ") && !strings.Contains(message, " you you ") {
|
|
score += 5
|
|
}
|
|
|
|
return score
|
|
}
|
|
|
|
// scoreAdvancedResponse scores a response with enhanced question-specific metrics
|
|
func scoreAdvancedResponse(response string, questionType string) int {
|
|
score := scoreAdvancedMessage(response) // Base score
|
|
|
|
// Question-specific bonuses
|
|
responseLower := strings.ToLower(response)
|
|
switch questionType {
|
|
case "yesno":
|
|
if strings.Contains(responseLower, "yes") || strings.Contains(responseLower, "no") ||
|
|
strings.Contains(responseLower, "maybe") || strings.Contains(responseLower, "definitely") {
|
|
score += 8
|
|
}
|
|
case "why":
|
|
if strings.Contains(responseLower, "because") || strings.Contains(responseLower, "reason") ||
|
|
strings.Contains(responseLower, "since") || strings.Contains(responseLower, "due") {
|
|
score += 8
|
|
}
|
|
case "how":
|
|
if strings.Contains(responseLower, "way") || strings.Contains(responseLower, "method") ||
|
|
strings.Contains(responseLower, "process") || strings.Contains(responseLower, "steps") {
|
|
score += 8
|
|
}
|
|
case "when":
|
|
if strings.Contains(responseLower, "time") || strings.Contains(responseLower, "day") ||
|
|
strings.Contains(responseLower, "hour") || strings.Contains(responseLower, "moment") {
|
|
score += 8
|
|
}
|
|
case "where":
|
|
if strings.Contains(responseLower, "place") || strings.Contains(responseLower, "location") ||
|
|
strings.Contains(responseLower, "here") || strings.Contains(responseLower, "there") {
|
|
score += 8
|
|
}
|
|
}
|
|
|
|
return score
|
|
}
|
|
|
|
// isValidNextWord reports whether nextWord may follow wordHistory without
// producing an obvious repetition or grammar slip. All comparisons are
// case-insensitive. The word is rejected when:
//
//   - it equals the most recent word;
//   - it already occurs at least twice among the last three words;
//   - it is "a" directly after "you";
//   - it is an article ("a", "an", "the") directly after another article.
//
// An empty history accepts every word.
func isValidNextWord(wordHistory []string, nextWord string) bool {
	n := len(wordHistory)
	if n == 0 {
		return true
	}

	candidate := strings.ToLower(nextWord)
	previous := strings.ToLower(wordHistory[n-1])

	// Immediate repetition, which also covers "a a", "the the" and "you you".
	if previous == candidate {
		return false
	}

	// Too many occurrences within the three most recent words.
	if n >= 3 {
		hits := 0
		for _, past := range wordHistory[n-3:] {
			if strings.ToLower(past) == candidate {
				hits++
			}
		}
		if hits >= 2 {
			return false
		}
	}

	if previous == "you" && candidate == "a" {
		return false
	}

	isArticle := func(w string) bool {
		switch w {
		case "a", "an", "the":
			return true
		}
		return false
	}

	return !(isArticle(previous) && isArticle(candidate))
}
|
|
|
|
// selectBestNextWord chooses the best next word from available options
|
|
func selectBestNextWord(options []string, wordHistory []string) string {
|
|
if len(options) == 0 {
|
|
return ""
|
|
}
|
|
|
|
// Filter out invalid options
|
|
var validOptions []string
|
|
for _, option := range options {
|
|
if isValidNextWord(wordHistory, option) {
|
|
validOptions = append(validOptions, option)
|
|
}
|
|
}
|
|
|
|
// If no valid options, fall back to original options but try to avoid the worst ones
|
|
if len(validOptions) == 0 {
|
|
var fallbackOptions []string
|
|
for _, option := range options {
|
|
// At least avoid immediate repetition
|
|
if len(wordHistory) == 0 || strings.ToLower(wordHistory[len(wordHistory)-1]) != strings.ToLower(option) {
|
|
fallbackOptions = append(fallbackOptions, option)
|
|
}
|
|
}
|
|
if len(fallbackOptions) > 0 {
|
|
validOptions = fallbackOptions
|
|
} else {
|
|
validOptions = options
|
|
}
|
|
}
|
|
|
|
// Prefer longer, more meaningful words
|
|
var goodOptions []string
|
|
for _, option := range validOptions {
|
|
if len(option) > 2 && !isStopWord(strings.ToLower(option)) {
|
|
goodOptions = append(goodOptions, option)
|
|
}
|
|
}
|
|
|
|
if len(goodOptions) > 0 {
|
|
return goodOptions[rand.Intn(len(goodOptions))]
|
|
}
|
|
|
|
return validOptions[rand.Intn(len(validOptions))]
|
|
}
|