package jlama;

import javax.swing.tree.DefaultTreeModel;
import javax.swing.tree.TreeNode;
import java.awt.*;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/**
 * Holds the keyword nodes of a Swing tree model and funnels every read and
 * mutation of that model through the AWT event dispatch thread (EDT).
 *
 * <p>Mutating operations ({@link #addArticle}, {@link #merge(String, String)})
 * run immediately when called on the EDT and are queued with
 * {@link EventQueue#invokeLater} otherwise, so off-EDT callers must not assume
 * the change is visible when the method returns. {@link #getAllKeywords()}
 * blocks the caller until the EDT has produced the result.
 */
public class KeywordContainer {
    /** Root of the tree; it carries a {@code null} keyword and is skipped when visiting. */
    private final KeywordNode keywordRootNode = new KeywordNode(null);
    private final DefaultTreeModel model = new DefaultTreeModel(keywordRootNode);
    private final NodeChangeListener nodeChangeListener;

    /**
     * Creates an empty container.
     *
     * @param nodeChangeListener callback notified when keyword nodes are inserted or updated
     */
    public KeywordContainer(NodeChangeListener nodeChangeListener) {
        this.nodeChangeListener = nodeChangeListener;
    }

    /**
     * Returns the tree model backing this container.
     *
     * @return the live model (not a copy)
     */
    public DefaultTreeModel getModel() {
        return model;
    }

    /**
     * Adds an article to the node of the given keyword, creating the node
     * under the root if no node with an equal keyword exists yet, and then
     * reports the node as updated to the listener.
     *
     * <p>Executed on the EDT; asynchronous when called from another thread.
     *
     * @param keyword keyword whose node receives the article
     * @param article article to add
     */
    public void addArticle(Keyword keyword, Article article) {
        invokeOnEDT(() -> {
            KeywordNode node = findOrCreateNode(keyword);
            node.addArticle(article);
            nodeChangeListener.nodeUpdated(node);
        });
    }

    /**
     * Collects the keywords of all keyword nodes currently in the tree.
     *
     * <p>The tree is read on the EDT; a caller on another thread blocks until
     * the EDT has finished collecting.
     *
     * @return a new list with the keywords in depth-first enumeration order
     * @throws RuntimeException if the wait on the EDT is interrupted or the
     *         collection fails on the EDT
     */
    public List<Keyword> getAllKeywords() {
        List<Keyword> keywords = new ArrayList<>();
        invokeOnEDTAndWait(() -> addAllKeywords(keywords));
        return keywords;
    }

    /**
     * Merges the nodes of two keywords, looked up by keyword text. The node
     * whose keyword has the strictly larger {@code resultCount()} survives;
     * on a tie the second keyword's node survives. Nothing happens if either
     * keyword has no node or if both texts resolve to the same node.
     *
     * <p>Executed on the EDT; asynchronous when called from another thread.
     *
     * @param firstKeyword  text of the first keyword
     * @param secondKeyword text of the second keyword
     */
    public void merge(String firstKeyword, String secondKeyword) {
        invokeOnEDT(() -> {
            KeywordNode firstKeywordNode = findNode(firstKeyword);
            KeywordNode secondKeywordNode = findNode(secondKeyword);
            if (firstKeywordNode == null || secondKeywordNode == null) {
                return;
            }
            // Merging a node into itself would remove it from the tree (and
            // never terminate if it has children), so treat it as a no-op.
            if (firstKeywordNode == secondKeywordNode) {
                return;
            }
            if (firstKeywordNode.getKeyword().resultCount() > secondKeywordNode.getKeyword().resultCount()) {
                merge(secondKeywordNode, firstKeywordNode);
            } else {
                merge(firstKeywordNode, secondKeywordNode);
            }
        });
    }

    /** Appends the keyword of every keyword node to {@code keywords}. */
    private void addAllKeywords(List<Keyword> keywords) {
        visitKeywordNodes(node -> {
            keywords.add(node.getKeyword());
            return null; // null keeps the traversal going over all nodes
        });
    }

    /**
     * Returns the node whose keyword equals {@code keyword}, or inserts a new
     * node as the last child of the root and notifies the listener.
     */
    private KeywordNode findOrCreateNode(Keyword keyword) {
        KeywordNode existingNode = findNode(keyword);
        if (existingNode != null) {
            return existingNode;
        }
        KeywordNode node = new KeywordNode(keyword);
        model.insertNodeInto(node, keywordRootNode, keywordRootNode.getChildCount());
        nodeChangeListener.nodeInserted(node);
        return node;
    }

    /** Returns the first node whose keyword equals {@code keyword}, or {@code null}. */
    private KeywordNode findNode(Keyword keyword) {
        return visitKeywordNodes(node -> keyword.equals(node.getKeyword()) ? node : null);
    }

    /** Returns the first node whose keyword text equals {@code keyword}, or {@code null}. */
    private KeywordNode findNode(String keyword) {
        return visitKeywordNodes(node -> keyword.equals(node.getKeyword().text()) ? node : null);
    }

    /**
     * Walks the tree depth-first, passing every node that has a non-null
     * keyword to the visitor, and stops at the first non-null visitor result.
     *
     * @return the first non-null result, or {@code null} if the visitor
     *         returned {@code null} for every node
     */
    private <T> T visitKeywordNodes(KeywordNodeVisitor<T> visitor) {
        Iterator<TreeNode> it = keywordRootNode.depthFirstEnumeration().asIterator();
        while (it.hasNext()) {
            KeywordNode node = (KeywordNode) it.next();
            if (node.getKeyword() == null) {
                continue; // e.g. the root, which is created with a null keyword
            }
            T result = visitor.visit(node);
            if (result != null) {
                return result;
            }
        }
        return null;
    }

    /**
     * Merges {@code keywordNode} into {@code keywordTargetNode}: combines the
     * node contents via {@code KeywordNode.merge}, re-parents all children of
     * the source under the target (reporting each as inserted), removes the
     * source from the tree and finally reports the target as updated.
     *
     * <p>The two nodes must be distinct.
     * NOTE(review): a target that is a descendant of the source is not
     * handled here — confirm such nesting cannot occur.
     */
    private void merge(KeywordNode keywordNode, KeywordNode keywordTargetNode) {
        keywordTargetNode.merge(keywordNode);
        model.nodeChanged(keywordTargetNode);
        // Always take child 0: each removal shifts the remaining children down.
        while (keywordNode.getChildCount() > 0) {
            KeywordNode childNode = (KeywordNode) keywordNode.getChildAt(0);
            model.removeNodeFromParent(childNode);
            model.insertNodeInto(childNode, keywordTargetNode, keywordTargetNode.getChildCount());
            nodeChangeListener.nodeInserted(childNode);
        }
        model.removeNodeFromParent(keywordNode);
        nodeChangeListener.nodeUpdated(keywordTargetNode);
    }

    /**
     * Runs {@code runnable} on the EDT and returns once it has completed;
     * runs it inline when already on the EDT.
     *
     * @throws RuntimeException wrapping the {@link InterruptedException} or
     *         {@link InvocationTargetException} raised while waiting
     */
    private void invokeOnEDTAndWait(Runnable runnable) {
        if (EventQueue.isDispatchThread()) {
            runnable.run();
            return;
        }
        try {
            EventQueue.invokeAndWait(runnable);
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers further up can still observe it.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (InvocationTargetException e) {
            throw new RuntimeException(e);
        }
    }

    /** Runs {@code runnable} inline when on the EDT, otherwise queues it for the EDT. */
    private void invokeOnEDT(Runnable runnable) {
        if (EventQueue.isDispatchThread()) {
            runnable.run();
        } else {
            EventQueue.invokeLater(runnable);
        }
    }

    /** Callback for {@link #visitKeywordNodes}; a non-null return value stops the traversal. */
    private interface KeywordNodeVisitor<T> {
        T visit(KeywordNode node);
    }

    /** Receives notifications about structural and content changes of keyword nodes. */
    public interface NodeChangeListener {
        /** Called after {@code node} has been inserted into the tree (including re-parenting during a merge). */
        void nodeInserted(KeywordNode node);

        /** Called after {@code node} has received an article or absorbed another node in a merge. */
        void nodeUpdated(KeywordNode node);
    }
}
