package jlama;

import com.github.tjake.jlama.util.Downloader;

import javax.swing.*;
import javax.swing.FocusManager;
import java.awt.*;
import java.io.File;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.nio.file.Path;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutionException;
import java.util.function.Consumer;

/**
 * Downloads the local LLM model used by the demo into {@link #MODEL_CACHE_PATH}, asking the user
 * for confirmation on the first progress callback of the download.
 *
 * <p>{@link #withInstallation(Runnable)} drives the download from a {@link SwingWorker} and shows a
 * progress dialog; {@link #main(String[])} runs the same download and prints progress to standard
 * output.
 */
public class ModelInstaller {
    /** Directory under the user's home where downloaded model files are cached. */
    public static final Path MODEL_CACHE_PATH = Path.of(System.getProperty("user.home"), ".jprofiler16", "jlama");
    /** Hugging Face model identifier that is downloaded. */
    public static final String MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct";
    private static final long MB_FACTOR = 1024 * 1024;

    /** Set once the user has confirmed the download; also decides whether a summary is printed. */
    private boolean downloadStarted;
    /** Last progress passed to the consumer, used to drop consecutive duplicate updates. */
    private ProgressInfo lastProgressInfo;
    /**
     * Owner frame created by {@link #getParentWindow()} when no window is active. Only accessed on
     * the event dispatch thread (from process()/done() and inside EventQueue.invokeAndWait).
     */
    private JFrame temporaryOwner;
    private SwingWorker<Boolean, ProgressInfo> worker;

    /**
     * Downloads the model from the command line, printing progress on a single, repeatedly
     * overwritten console line.
     */
    public static void main(String[] args) throws IOException {
        new ModelInstaller().install(progressInfo -> {
            String line = "Downloading " + progressInfo.filename() + ": " + progressInfo.percent() + "% of " + progressInfo.totalSizeMb() + " MB\r";
            System.out.print(line);
            System.out.flush();
        }, () -> false);
    }

    /**
     * Ensures the model is installed and then runs {@code runnable} on the event dispatch thread.
     *
     * <p>The download runs in a background {@link SwingWorker}; a progress dialog is shown when the
     * first progress update arrives. If the user cancels, {@code runnable} is not run. The calling
     * thread blocks until the worker completes or is cancelled, so this method must not be called
     * on the event dispatch thread: the confirmation dialog and progress updates need the EDT while
     * this method waits.
     *
     * @param runnable action to run once the model is available
     */
    public void withInstallation(Runnable runnable) {
        worker = new SwingWorker<>() {
            private ProgressDialog progressDialog;

            @Override
            protected Boolean doInBackground() throws Exception {
                return install(this::publish, this::isCancelled);
            }

            @Override
            protected void process(List<ProgressInfo> chunks) {
                if (chunks.isEmpty()) {
                    return;
                }
                if (progressDialog == null) {
                    ProgressDialog dialog = new ProgressDialog(getParentWindow(), () -> cancel(true));
                    progressDialog = dialog;
                    // Showing is deferred so process() is not blocked if the dialog is modal.
                    // Skip it if the worker already finished, otherwise done() may have disposed
                    // the dialog before this runs and it would be shown again.
                    SwingUtilities.invokeLater(() -> {
                        if (!isDone()) {
                            dialog.setVisible(true);
                        }
                    });
                }
                // Only the most recent update in the batch is relevant.
                ProgressInfo progressInfo = chunks.getLast();
                int totalSizeMb = progressInfo.totalSizeMb();
                String downloadingText = "Downloading " + progressInfo.filename();
                if (totalSizeMb > 0) {
                    progressDialog.setProgress(downloadingText + " (" + totalSizeMb + " MB)", progressInfo.percent());
                } else {
                    progressDialog.setText(downloadingText);
                }
            }

            @Override
            protected void done() {
                if (progressDialog != null) {
                    progressDialog.setVisible(false);
                    progressDialog.dispose();
                }
                disposeTemporaryOwner();

                boolean installed;
                try {
                    installed = get();
                } catch (CancellationException ignored) {
                    // The user cancelled from the progress dialog; skip the follow-up action.
                    return;
                } catch (InterruptedException e) {
                    // Not expected: done() is only invoked after the worker has finished.
                    Thread.currentThread().interrupt();
                    return;
                } catch (ExecutionException e) {
                    throw new RuntimeException("Installing " + MODEL_NAME + " failed", e.getCause());
                }
                // Run outside the try so the caller's exceptions are not wrapped or misreported.
                if (installed) {
                    runnable.run();
                }
            }
        };
        worker.execute();
        try {
            worker.get(); // Keep the process alive
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        } catch (CancellationException | ExecutionException ignored) {
            // Cancellation and failures are handled in done() on the event dispatch thread.
        }
    }

    /**
     * Downloads {@link #MODEL_NAME} into {@link #MODEL_CACHE_PATH}.
     *
     * <p>On the first progress callback the user is asked whether to proceed; declining terminates
     * the JVM. If {@code cancelSource} reports cancellation, the JVM is terminated at the next
     * progress callback.
     *
     * @param progressInfoConsumer receives progress updates, with consecutive duplicates dropped
     * @param cancelSource polled on every progress callback
     * @return {@code true} after {@code Downloader.huggingFaceModel()} completes
     * @throws RuntimeException wrapping the {@link IOException} if the download fails
     */
    private boolean install(Consumer<ProgressInfo> progressInfoConsumer, CancelSource cancelSource) {
        Downloader downloader = new Downloader(MODEL_CACHE_PATH.toString(), MODEL_NAME);
        downloader.withProgressReporter((filename, sizeDownloaded, totalSize) -> {
            if (cancelSource.isCancelled()) {
                exit();
            }
            if (!downloadStarted) {
                confirmDownload();
                downloadStarted = true;
            }
            int mb = Math.round(1f * totalSize / MB_FACTOR);
            // An unknown (non-positive) total would otherwise produce NaN/Infinity, which
            // Math.round turns into 0 or Integer.MAX_VALUE.
            int percent = totalSize > 0 ? Math.round(100f * sizeDownloaded / totalSize) : 0;
            ProgressInfo progressInfo = new ProgressInfo(filename, percent, mb);
            if (!Objects.equals(progressInfo, lastProgressInfo)) {
                progressInfoConsumer.accept(progressInfo);
                lastProgressInfo = progressInfo;
            }
        });

        try {
            File file = downloader.huggingFaceModel();
            if (downloadStarted) {
                System.out.println();
                System.out.println("Successfully downloaded model file " + file);
            }
            return true;
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Asks the user, on the event dispatch thread, whether to download the model. The JVM is
     * terminated unless "Proceed" is chosen.
     */
    private void confirmDownload() {
        try {
            EventQueue.invokeAndWait(() -> {
                Dashboard.initLaF();
                int result;
                try {
                    result = JOptionPane.showOptionDialog(getParentWindow(),
                            "This demo requires a local LLM model.\n\n" +
                            MODEL_NAME + " will now be downloaded from Hugging Face.\n" +
                            "It will be cached for future sessions.\n\n" +
                            "Proceed or exit?", Dashboard.WINDOW_TITLE, JOptionPane.DEFAULT_OPTION, JOptionPane.QUESTION_MESSAGE, null, new String[] {"Proceed", "Exit"}, "Exit");
                } finally {
                    // Otherwise the owner frame stays displayable and keeps AWT (and the JVM) alive.
                    disposeTemporaryOwner();
                }
                if (result != 0) {
                    exit();
                }
            });
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (InvocationTargetException e) {
            throw new RuntimeException(e);
        }
    }

    /** Terminates the JVM with a non-zero status; used when the user declines or cancels. */
    private void exit() {
        System.exit(-1);
    }

    /**
     * Returns the currently active window. If no window is active, returns an undecorated owner
     * frame instead; that frame is created once and reused until
     * {@link #disposeTemporaryOwner()} is called. Must be called on the event dispatch thread.
     */
    private Window getParentWindow() {
        Window activeWindow = FocusManager.getCurrentManager().getActiveWindow();
        if (activeWindow != null) {
            return activeWindow;
        }
        if (temporaryOwner == null) {
            JFrame frame = new JFrame(Dashboard.WINDOW_TITLE);
            frame.setUndecorated(true);
            // Position before showing so the frame is not moved after it appears.
            frame.setLocationRelativeTo(null);
            frame.setVisible(true);
            temporaryOwner = frame;
        }
        return temporaryOwner;
    }

    /** Disposes the owner frame created by {@link #getParentWindow()}, if any. EDT only. */
    private void disposeTemporaryOwner() {
        if (temporaryOwner != null) {
            temporaryOwner.dispose();
            temporaryOwner = null;
        }
    }

    /** One progress update: the file being downloaded, percent complete, and total size in MB. */
    private record ProgressInfo(String filename, int percent, int totalSizeMb) {
    }

    /** Reports whether the user has asked to cancel the download. */
    @FunctionalInterface
    private interface CancelSource {
        boolean isCancelled();
    }
}
