# https://gradio.app/blocks-and-event-listeners/
# demo: http://10.147.17.2:7860/

import json
import os
import pdb
from operator import itemgetter
from pathlib import Path

import fire
import gradio as gr
import numpy as np
import openai
import pandas as pd
from sentence_transformers import SentenceTransformer, util
from tqdm import tqdm

from .zca import ZCA


class CVFDemo:
    """Example Competing Values Framework (CVF) configuration.

    Bundles the example dimension names, one seed sentence per dimension
    (same order as ``dim_names``), and two scales. Each scale lists the
    dimensions on its "Positive" side and on its "Negative" side.
    These are the example values used in the paper.
    """

    dim_names = ["Create", "Collaborate", "Control", "Compete"]

    # Seed sentence i belongs to dim_names[i].
    dim_seeds = [
        "We should adapt and innovate.",
        "We should empathize and collaborate.",
        "We should control and stabilize.",
        "We should respond swiftly and serve customers.",
    ]

    scales = {
        "External-Internal": dict(
            Positive=["Create", "Compete"],
            Negative=["Control", "Collaborate"],
        ),
        "Flexible-Stable": dict(
            Positive=["Collaborate", "Create"],
            Negative=["Control", "Compete"],
        ),
    }


class PathManager:
    """Hold the output directory and the location of the bundled sample data."""

    def __init__(self, out_dir):
        """Remember ``out_dir`` and derive paths from this source file.

        Args:
            out_dir: Directory where generated files are written; stored
                unchanged on ``self.out_dir``.
        """
        this_file = os.path.abspath(__file__)
        self.out_dir = out_dir
        # root_dir stays a plain string; sample_data_dir is a Path below it.
        self.root_dir = os.path.dirname(this_file)
        self.sample_data_dir = Path(self.root_dir) / "sample_data"
        print(self.sample_data_dir)


class Meaurement:
    def __init__(self, path_mgt):
        """Start an empty measurement session.

        Args:
            path_mgt: Path manager object; other methods read its
                ``out_dir`` and ``sample_data_dir`` attributes.
        """
        self.path_mgt = path_mgt

        # Embedding backend: nothing is selected or loaded yet.
        self.model_name = self.model = None
        self.use_openAI = False

        # Input table, its document embeddings and the chosen columns.
        self.input_df = self.embeddings = None
        self.col_names = self.doc_col_name = self.doc_id_col_name = None

        # Number of dimension / scale rows currently in use.
        self.n_dims = self.n_scales = None

        # Per-dimension and per-scale results, keyed by name.
        self.dim_embeddings, self.dim_queries = {}, {}
        self.scale_embeddings, self.scale_definitions = {}, {}

    def read_csv_cols(self, file_obj):
        """Load the uploaded CSV and offer its columns in both selectors.

        Args:
            file_obj: Value delivered by the file-upload component. Either a
                list of uploaded file objects, a single file object, or
                ``None`` / an empty list when the upload has been cleared.
                Each file object must expose the path as ``.name``.

        Returns:
            A pair of ``gr.Dropdown`` updates, one per column selector.
        """
        cleared = file_obj is None or (
            isinstance(file_obj, (list, tuple)) and len(file_obj) == 0
        )
        if cleared:
            # Upload was removed: forget the old table instead of failing on
            # ``None[0]`` and disable the selectors until a new file arrives.
            self.input_df = None
            self.col_names = None
            return (
                gr.Dropdown.update(choices=[], interactive=False),
                gr.Dropdown.update(choices=[], interactive=False),
            )

        # Accept both the list form and a single file object.
        uploaded = file_obj[0] if isinstance(file_obj, (list, tuple)) else file_obj
        self.input_df = pd.read_csv(uploaded.name)
        self.col_names = self.input_df.columns.tolist()
        print(f"Columns: {self.col_names}")
        return (
            gr.Dropdown.update(choices=self.col_names, interactive=True),
            gr.Dropdown.update(choices=self.col_names, interactive=True),
        )

    def read_input_embedding(self, file_obj):
        """Load precomputed document embeddings and check them against the CSV.

        Args:
            file_obj: Value delivered by the embedding-upload component. Either
                a list of uploaded file objects, a single file object, or
                ``None`` / an empty list when the upload has been cleared.
                Each file object must expose the path as ``.name``.

        Returns:
            A pair ``(gr.Textbox update, gr.Button update)``: a status message
            and an update that hides the button.
        """
        hide_button = gr.Button.update(visible=False)

        cleared = file_obj is None or (
            isinstance(file_obj, (list, tuple)) and len(file_obj) == 0
        )
        if cleared:
            # Nothing to load; leave any existing embeddings untouched.
            return (
                gr.Textbox.update(value="No precomputed embedding file selected."),
                hide_button,
            )

        uploaded = file_obj[0] if isinstance(file_obj, (list, tuple)) else file_obj
        loaded = np.load(uploaded.name)
        if hasattr(loaded, "files"):
            # ``np.load`` returns an archive object (not an array) for .npz
            # files; use the first array stored in it.
            try:
                if not loaded.files:
                    return (
                        gr.Textbox.update(
                            value="Error: The uploaded .npz file does not contain any array."
                        ),
                        hide_button,
                    )
                loaded = loaded[loaded.files[0]]
            finally:
                loaded_archive = getattr(loaded, "close", None)
                if loaded_archive is not None:
                    loaded_archive()
        self.embeddings = loaded

        if self.input_df is None:
            # Without the CSV the row count cannot be validated; keep the
            # embeddings but tell the user instead of raising on len(None).
            msg = f"Precomputed embedding uploaded, but no input CSV file is loaded yet, so the number of rows could not be checked. The shape of the embeddings is {self.embeddings.shape}. Upload the input CSV file before proceeding."
        elif self.embeddings.shape[0] != len(self.input_df):
            msg = f"Error: The number of rows in the input data and the number of rows in the embedding matrix do not match! The shape of the embeddings is {self.embeddings.shape}, and the number of rows in the input data is {len(self.input_df)}. "
        else:
            msg = f"Precomputed embedding uploaded! \nThe shape of the embeddings is {self.embeddings.shape}. Proceed to the next tab to define the dimensions."
        return gr.Textbox.update(value=msg), hide_button

    def set_doc_col(self, doc_col_name, doc_id_col_name):
        """Store the selected text and document-ID column names.

        A name is stored only when it is one of the columns read from the
        uploaded CSV (``self.col_names``). Success is reported only when both
        names were accepted; previously the success label was returned even
        when nothing had been set.

        Args:
            doc_col_name: Name of the column holding the document text.
            doc_id_col_name: Name of the column holding the document IDs.

        Returns:
            A ``gr.Button`` update whose ``value`` is the status label (the
            result is wired to the confirm button).
        """
        known_columns = self.col_names if self.col_names is not None else []
        not_set = []

        print(f"Setting doc_col_name to {doc_col_name}")
        if doc_col_name in known_columns:
            self.doc_col_name = doc_col_name
        else:
            not_set.append("text column")

        print(f"Setting doc_id_col_name to {doc_id_col_name}")
        if doc_id_col_name in known_columns:
            self.doc_id_col_name = doc_id_col_name
        else:
            not_set.append("document ID column")

        if not_set:
            return gr.Button.update(
                value=f"Not set: please select a valid {' and '.join(not_set)}"
            )
        return gr.Button.update(value="Column names set!")

    def set_sbert(self, model_name):
        """Load a Sentence Transformers model and make it the active encoder.

        Args:
            model_name: Model identifier passed to ``SentenceTransformer``.
        """
        print(f"Setting SBERT model to {model_name}")
        # The name is recorded before loading, so it is kept even if the
        # constructor below raises.
        self.model_name = model_name
        encoder = SentenceTransformer(model_name)
        self.model = encoder

    def set_openai_api_key(self, api_key):
        """Configure the OpenAI API key from the textbox or the environment.

        A non-blank ``api_key`` is used directly (surrounding whitespace is
        removed). A blank or missing key falls back to the ``OPENAI_API_KEY``
        environment variable.

        Args:
            api_key: Key typed by the user; may be empty.

        Returns:
            ``None`` on success, or a ``gr.Textbox`` update carrying an error
            message when the environment variable is not defined.
        """
        print("Setting OpenAI API key")
        if api_key is not None and api_key.strip() != "":
            # A plain attribute assignment cannot fail, so no try/except here.
            openai.api_key = api_key.strip()
            return None

        print("Using environment variable (OPENAI_API_KEY) for API key.")
        try:
            openai.api_key = os.environ["OPENAI_API_KEY"]
        except KeyError:
            # Only a missing variable is expected; anything else should surface.
            return gr.Textbox.update(
                value="Invalid API key in OPENAI_API_KEY environment variable. Please try again.",
                visible=True,
            )
        return None

    def toggle_embedding_model_visibility(
        self,
        embedding_model_dropdown,
    ):
        """Show the input box that matches the selected embedding backend.

        Args:
            embedding_model_dropdown: Selected dropdown label.

        Returns:
            A 3-tuple of updates: visibility for two textboxes (first shown
            for "Sentence Transformers", second shown for "OpenAI") and a
            button update that restores the "Set Embedding Model" label.
            ``None`` when the label matches neither backend.
        """
        wants_sbert = embedding_model_dropdown.startswith("Sentence Transformers")
        wants_openai = embedding_model_dropdown.startswith("OpenAI")
        if not (wants_sbert or wants_openai):
            return None
        return (
            gr.Textbox.update(visible=wants_sbert),
            gr.Textbox.update(visible=not wants_sbert),
            gr.Button.update(value="Set Embedding Model"),
        )

    def set_embedding_model(
        self,
        embedding_model_dropdown,
        sbert_model_textbox,
        openai_api_key,
    ):
        """Activate the embedding backend chosen in the dropdown.

        Args:
            embedding_model_dropdown: Selected dropdown label.
            sbert_model_textbox: Sentence Transformers model name.
            openai_api_key: OpenAI API key text (may be empty).

        Returns:
            A ``gr.Button`` update with the label "Embedding model set!".
        """
        print(f"Setting embedding model to {embedding_model_dropdown}")
        choice = embedding_model_dropdown
        # The two prefixes are mutually exclusive, so at most one branch runs.
        if choice.startswith("Sentence Transformers"):
            self.set_sbert(sbert_model_textbox)
            self.use_openAI = False
        elif choice.startswith("OpenAI"):
            self.set_openai_api_key(openai_api_key)
            self.use_openAI = True
        confirmation = gr.Button.update(value="Embedding model set!")
        return confirmation

    def reset_set_emb_btn(self):
        """Return a button update that restores the "Set Embedding Model" label."""
        default_label = "Set Embedding Model"
        return gr.Button.update(value=default_label)

    def set_embedding_option(self, embedding_option):
        """Toggle between the upload widget and the embed button.

        Args:
            embedding_option: Selected radio label.

        Returns:
            ``(gr.File update, gr.Button update)`` with opposite visibility,
            or ``None`` when the label is not one of the two known options.
        """
        if embedding_option == "Upload Precomputed Embeddings":
            uploader_visible = True
        elif embedding_option == "Embed Documents":
            uploader_visible = False
        else:
            return None
        return (
            gr.File.update(visible=uploader_visible),
            gr.Button.update(visible=not uploader_visible),
        )

    def embed_texts(self, sentences, progress=gr.Progress()):
        """Embed texts with the active backend and return unit-length vectors.

        Args:
            sentences: Iterable of texts to embed.
            progress: Not referenced in the body.
                NOTE(review): presumably declared so Gradio can attach a
                progress tracker to the event - confirm.

        Returns:
            The embeddings returned by ``self.model.encode`` (normalised by
            the encoder), or, for OpenAI, a NumPy array with one L2-normalised
            row per input text.
        """
        if self.use_openAI is not False:
            print("Embedding Using OpenAI")
            # One API request per text.
            raw_vectors = np.array(
                [
                    openai.Embedding.create(
                        input=text, model="text-embedding-ada-002"
                    )["data"][0]["embedding"]
                    for text in tqdm(sentences)
                ]
            )
            # normalize embeddings row by row
            row_norms = np.linalg.norm(raw_vectors, axis=1, keepdims=True)
            return raw_vectors / row_norms

        # use sentence_transformers to embed the texts
        print("Embedding Using Sentence Transformers")
        return self.model.encode(
            sentences, show_progress_bar=True, normalize_embeddings=True
        )

    def embed_df(self):
        """Embed the selected text column and save the result as ``embeddings.npy``.

        Returns:
            A 3-tuple: a ``gr.Textbox`` update with a completion message, a
            ``gr.File`` update making the file widget visible, and the path of
            the saved ``.npy`` file in the output directory.
        """
        documents = self.input_df[self.doc_col_name]
        self.embeddings = self.embed_texts(documents)

        saved_path = Path(self.path_mgt.out_dir, "embeddings.npy")
        np.save(saved_path, self.embeddings)

        status = f"Embedding completed! \nThe shape of the embeddings is {self.embeddings.shape}. You can download and save the embeddings below. Proceed to the next tab to define the dimensions."
        return (
            gr.Textbox.update(value=status),
            gr.File.update(visible=True),
            saved_path,
        )

    def init_query_boxes(self, n_dims):
        """Deprecated: build visibility updates for 10 groups of dimension widgets.

        Records ``n_dims`` and, for each of 10 slots, emits four updates
        (textbox, textbox, button, textbox) that are visible only for the
        first ``n_dims`` slots.

        Args:
            n_dims: Number of dimensions to show.

        Returns:
            A flat list of 40 component updates.
        """
        # ! deprecated
        self.n_dims = n_dims
        widget_updates = []
        for slot in range(10):
            shown = bool(slot < n_dims)
            widget_updates += [
                gr.Textbox.update(visible=shown),
                gr.Textbox.update(visible=shown),
                gr.Button.update(visible=shown),
                gr.Textbox.update(visible=shown),
            ]
        return widget_updates

    def toggle_row_vis(self, n_rows):
        """Record the dimension count and show only that many rows (dimensions tab).

        Args:
            n_rows: Number of dimension rows to show; stored on ``self.n_dims``.

        Returns:
            A list of 10 ``gr.Row`` updates, visible for the first ``n_rows``.
        """
        self.n_dims = n_rows
        return [gr.Row.update(visible=bool(row < n_rows)) for row in range(10)]

    def toggle_row_vis_scales(self, n_rows):
        """Record the scale count and show only that many rows (scales tab).

        Args:
            n_rows: Number of scale rows to show; stored on ``self.n_scales``.

        Returns:
            A list of 10 ``gr.Row`` updates, visible for the first ``n_rows``.
        """
        self.n_scales = n_rows
        return [gr.Row.update(visible=bool(row < n_rows)) for row in range(10)]

    def semantic_search(self, query, n_results):
        """Find the documents most similar to the mean of the query sentences.

        Each non-blank line of ``query`` is embedded, the embeddings are
        averaged, and the average is scored against ``self.embeddings`` with a
        dot product.

        Args:
            query: Query sentences, one per line.
            n_results: Maximum number of documents to return.

        Returns:
            A ``gr.Textbox`` update whose value lists, per hit, the document
            ID, the score rounded to 3 decimals and the document text; or a
            short instruction when the search cannot be run yet.
        """
        if self.embeddings is None or self.input_df is None:
            return gr.Textbox.update(
                value="No document embeddings are loaded. Embed the documents or upload precomputed embeddings in Tab 1 first."
            )
        if self.doc_col_name is None or self.doc_id_col_name is None:
            return gr.Textbox.update(
                value="Select the text column and the document ID column in Tab 1 first."
            )

        # remove empty queries or queries with only spaces
        queries = [q for q in (query or "").split("\n") if q.strip() != ""]
        if not queries:
            # Averaging zero embeddings would give NaNs or an error.
            return gr.Textbox.update(
                value="Please enter at least one query sentence."
            )

        query_embedding = self.embed_texts(queries)
        mean_vect_query = query_embedding.mean(axis=0)
        hits = util.semantic_search(
            mean_vect_query,
            self.embeddings,
            score_function=util.dot_score,
            top_k=int(n_results),
        )
        top_hits = hits[0]
        hit_ids = [h["corpus_id"] for h in top_hits]
        hit_scores = [h["score"] for h in top_hits]

        # corpus_id is a row position in self.embeddings, so look rows up by
        # position. (itemgetter(*ids) returned a bare value for a single hit
        # and raised for zero hits.)
        examples = self.input_df[self.doc_col_name].iloc[hit_ids].tolist()
        example_doc_ids = self.input_df[self.doc_id_col_name].iloc[hit_ids].tolist()

        result_parts = [
            f"Document ID: {doc_id} \n"
            f"Score: {round(score, 3)} \n"
            "------------------------ \n"
            f"{text}\n"
            "------------------------ \n"
            for doc_id, score, text in zip(example_doc_ids, hit_scores, examples)
        ]
        return gr.Textbox.update(value="".join(result_parts))

    def save_dims(self, *dims_boxes):
        """Embed each dimension's queries and save the definitions to JSON.

        Args:
            *dims_boxes: Textbox values; the first half are dimension names,
                the second half are the matching query blocks (one query per
                line). Only the first ``self.n_dims`` entries are used.

        Returns:
            A list of 23 items: 20 copies of a ``gr.Dropdown`` update offering
            the saved dimension names, a ``gr.Textbox`` status update, a
            ``gr.File`` visibility update, and the path of
            ``dimensions_queries.json``.
        """
        half = len(dims_boxes) // 2
        name_boxes, query_boxes = dims_boxes[:half], dims_boxes[half:]

        self.dim_embeddings = {}
        self.dim_queries = {}
        for idx in range(self.n_dims):
            typed_name = name_boxes[idx].strip()
            # fall back to a default name when the name box is blank
            dim_name = typed_name if typed_name != "" else f"Dimension_{idx+1}"
            # keep only lines that contain something other than whitespace
            seeds = [line for line in query_boxes[idx].split("\n") if line.strip() != ""]
            # a dimension is represented by the mean of its query embeddings
            self.dim_embeddings[dim_name] = self.embed_texts(seeds).mean(axis=0)
            self.dim_queries[dim_name] = seeds

        # save self.dim_queries to json file
        json_path = Path(self.path_mgt.out_dir, "dimensions_queries.json")
        with open(json_path, "w") as f:
            json.dump(self.dim_queries, f)

        saved_names = list(self.dim_queries.keys())
        # update both positive and negative scale selectors
        selector_updates = [gr.Dropdown.update(choices=saved_names)] * 20
        status_update = gr.Textbox.update(
            visible=True,
            value=f"Dimensions saved: {saved_names}. You can download the json file below to keep a record of the final queries. Proceed to the next tab to define scales.",
        )
        return selector_updates + [
            status_update,
            gr.File.update(visible=True),
            json_path,
        ]

    def save_scales(self, *scale_boxes):
        """Build scale vectors from saved dimensions and write the definitions.

        A scale vector is the mean of its positive dimensions' embeddings
        minus the mean of its negative dimensions' embeddings. With only one
        side selected, the vector is that side's mean (negated for the
        negative side). A scale with no dimensions on either side is skipped;
        previously it raised or silently reused the previous scale's vector.

        Saved definitions have this JSON shape:
            {
                "External-Internal": {
                    "Positive": ["Create", "Compete"],
                    "Negative": ["Control", "Collaborate"]
                },
                ...
            }

        Args:
            *scale_boxes: Widget values; the first third are scale names, the
                second third the positive-side selections, the last third the
                negative-side selections. Only the first ``self.n_scales``
                rows are used; when ``self.n_scales`` is unset, every row is
                considered and empty rows are skipped.

        Returns:
            A list of 3 items: a ``gr.Textbox`` status update, a ``gr.File``
            visibility update, and the path of ``scale_definitions.json``.

        Raises:
            KeyError: If a selected dimension has no saved embedding.
        """

        def as_dim_list(selection):
            # Selectors may deliver None (nothing chosen) or a single string.
            if selection is None:
                return []
            if isinstance(selection, str):
                return [selection]
            return list(selection)

        all_scale_names = scale_boxes[: len(scale_boxes) // 3]
        all_pos_scales = scale_boxes[len(scale_boxes) // 3 : 2 * len(scale_boxes) // 3]
        all_neg_scales = scale_boxes[2 * len(scale_boxes) // 3 :]
        n_scales = self.n_scales if self.n_scales is not None else len(all_scale_names)

        self.scale_embeddings = {}
        self.scale_definitions = {}
        skipped = []
        for i in range(n_scales):
            typed_name = (all_scale_names[i] or "").strip()
            # if no name is given, use default name
            scale_name = typed_name if typed_name != "" else f"Scale_{i+1}"

            pos_dims = as_dim_list(all_pos_scales[i])
            neg_dims = as_dim_list(all_neg_scales[i])
            pos_vectors = [self.dim_embeddings[dim] for dim in pos_dims]
            neg_vectors = [self.dim_embeddings[dim] for dim in neg_dims]

            if not pos_vectors and not neg_vectors:
                # Nothing to project onto: do not store a definition without
                # a matching embedding.
                skipped.append(scale_name)
                continue

            if pos_vectors and neg_vectors:
                # take averages and difference
                scale_embedding = np.stack(pos_vectors).mean(axis=0) - np.stack(
                    neg_vectors
                ).mean(axis=0)
            elif pos_vectors:  # only positive side
                scale_embedding = np.stack(pos_vectors).mean(axis=0)
            else:  # only negative side
                scale_embedding = -np.stack(neg_vectors).mean(axis=0)

            self.scale_definitions[scale_name] = {
                "Positive": pos_dims,
                "Negative": neg_dims,
            }
            self.scale_embeddings[scale_name] = scale_embedding

        # save the scale definitions to json file
        json_path = Path(self.path_mgt.out_dir, "scale_definitions.json")
        with open(json_path, "w") as f:
            json.dump(self.scale_definitions, f)

        message = f"Scales saved: {list(self.scale_definitions.keys())}. You can download the json file below to keep a record of the scale definitions. Proceed to the next tab to measure using semantic projection."
        if skipped:
            message += f" Skipped (no dimensions selected): {skipped}."
        return [
            gr.Textbox.update(visible=True, value=message),
            gr.File.update(visible=True),
            json_path,
        ]

    def measure_docs(self, whitening: str):
        """Score every document on every saved scale and write a CSV.

        Each score is the cosine similarity between a document embedding and
        a scale embedding, rounded to 4 decimals. The CSV has the document-ID
        column first, then one column per scale.

        Args:
            whitening: ``"ZCA"`` applies a ZCA transform fitted on the score
                matrix before rounding; any other value leaves scores as is.

        Returns:
            A 3-tuple: a ``gr.Textbox`` status update, a ``gr.File``
            visibility update, and the path of ``measurement_output.csv``.

        Raises:
            ValueError: If no scales have been saved.
        """
        scale_names = list(self.scale_definitions.keys())
        if not scale_names:
            raise ValueError(
                "No scales are defined. Save at least one scale before measuring."
            )

        per_scale = []
        for scale in scale_names:
            measures = (
                util.pytorch_cos_sim(self.embeddings, self.scale_embeddings[scale])
                .cpu()
                .numpy()
            )
            # change measures to one dimensional array
            per_scale.append(measures.reshape(measures.shape[0]))

        # rows = documents, columns = scales
        scale_measures = np.stack(per_scale).T
        if whitening == "ZCA":
            trf = ZCA().fit(scale_measures)
            scale_measures = trf.transform(scale_measures)
        scale_measures = scale_measures.round(4)

        scale_measures = pd.DataFrame(scale_measures)
        scale_measures.columns = scale_names
        scale_measures[self.doc_id_col_name] = self.input_df[self.doc_id_col_name]
        scale_measures = scale_measures[[self.doc_id_col_name] + scale_names]

        # Write into the configured output directory (not a hard-coded
        # "measure_output" folder in the current working directory).
        out_dir = Path(self.path_mgt.out_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        output_path = Path(self.path_mgt.out_dir, "measurement_output.csv")
        scale_measures.to_csv(output_path, index=False)
        return (
            gr.Textbox.update(
                visible=True,
                value="Measurement completed. Download the results below. ",
            ),
            gr.File.update(visible=True),
            output_path,
        )

    def load_example_dataset(self):
        """Load the bundled sample data and pre-fill the CVF example.

        Reads ``sample_text.csv`` and ``sample_emb.npy`` from the sample-data
        directory, selects the ``all-MiniLM-L6-v2`` Sentence Transformers
        model, and prepares values for 10 dimension rows and 10 scale rows
        from ``CVFDemo``. Counts are derived from ``CVFDemo`` rather than
        hard-coded, so they stay in step with the example definition.

        Returns:
            A flat list: a ``gr.Row`` visibility update, the two sample file
            paths, 10 dimension-name updates, 10 seed-sentence updates, 10
            scale-name values, 10 positive-selector values, 10
            negative-selector values (``None`` for unused scale rows), and two
            ``gr.Slider`` updates (number of dimensions, number of scales).
        """
        n_rows = 10  # rows covered per group in the returned list
        sample_dir = self.path_mgt.sample_data_dir
        sample_csv = Path(sample_dir, "sample_text.csv")
        sample_emb = Path(sample_dir, "sample_emb.npy")

        # tab 1
        self.input_df = pd.read_csv(sample_csv)
        self.embeddings = np.load(sample_emb)
        self.set_sbert("all-MiniLM-L6-v2")
        self.doc_col_name = "text"
        self.doc_id_col_name = "doc_id"
        self.use_openAI = False

        # tab 2: class attributes are read directly, no instance is needed
        demo_dim_names = CVFDemo.dim_names
        demo_dim_seeds = CVFDemo.dim_seeds
        all_dim_name_boxes_updates = []
        all_dim_seed_boxes_updates = []
        for i in range(n_rows):
            in_use = i < len(demo_dim_names)
            all_dim_name_boxes_updates.append(
                gr.Textbox.update(value=f"{demo_dim_names[i]}" if in_use else "")
            )
            all_dim_seed_boxes_updates.append(
                gr.Textbox.update(value=f"{demo_dim_seeds[i]}" if in_use else "")
            )

        # tab 3
        demo_scales = CVFDemo.scales
        demo_scale_names = list(demo_scales.keys())
        all_scale_name_boxes_updates = []
        all_scale_pos_selector_updates = []
        all_scale_neg_selector_updates = []
        for i in range(n_rows):
            if i < len(demo_scale_names):
                definition = demo_scales[demo_scale_names[i]]
                all_scale_name_boxes_updates.append(
                    gr.Textbox.update(value=demo_scale_names[i])
                )
                all_scale_pos_selector_updates.append(
                    gr.Textbox.update(value=definition["Positive"])
                )
                all_scale_neg_selector_updates.append(
                    gr.Textbox.update(value=definition["Negative"])
                )
            else:
                all_scale_name_boxes_updates.append(None)
                all_scale_pos_selector_updates.append(None)
                all_scale_neg_selector_updates.append(None)

        return (
            [
                gr.Row.update(visible=True),
                sample_csv,
                sample_emb,
            ]
            + all_dim_name_boxes_updates
            + all_dim_seed_boxes_updates
            + all_scale_name_boxes_updates
            + all_scale_pos_selector_updates
            + all_scale_neg_selector_updates
            + [
                gr.Slider.update(value=min(len(demo_dim_names), n_rows)),
                gr.Slider.update(value=min(len(demo_scale_names), n_rows)),
            ]
        )


def run_gui(out_dir="measure_output/", mode="local", username=None, password=None):
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    path_mgt = PathManager(out_dir=out_dir)
    with gr.Blocks(title="SPAR") as demo:
        m = Meaurement(path_mgt=path_mgt)
        gr.Markdown(
            "### SPAR: Semantic Projection with Active Retrieval (Research Preview)"
        )
        all_dim_name_boxes = []
        all_search_query_boxes = []
        all_search_btns = []
        all_search_results = []
        all_rows_dims = []
        all_rows_scale = []
        all_scale_name_boxes = []
        all_scale_pos_selector = []
        all_scale_neg_selector = []
        with gr.Row():
            with gr.Column(scale=1):
                example_btn = gr.Button(value="💡 Load Example Dataset and Scales")
            with gr.Column(scale=8):
                gr.Markdown(
                    value="""* SPAR is a Python package and a web interface for measuring short text documents using semantic projection.
                    * The package is part of the manuscript ISR-2022-128 (under review). It is still considered a research prototype and under active development. 
                    * The source code is available on [GitHub](https://github.com/ISR2022128/SPAR_measure) under GPLv3 license.""",
                    label="",
                )
        example_row = gr.Row(visible=False, variant="panel")
        with example_row:
            with gr.Column(scale=2):
                example_load_msgbox = gr.Markdown(
                    visible=True,
                    label="",
                    value="""
                    
                    💡 __Example dataset and embeddings loaded. You may download the example dataset and embeddings on the right.__  
                    💡 __Do not change the settings in Tab 1. Proceed to Tab 2.__  
                    💡 __If you want to use your own dataset, refresh the page and upload a CSV file in Tab 1.__""",
                )
            with gr.Column(scale=1):
                example_file_download = gr.File(
                    visible=True, label="Download Example File"
                )
                example_emb_download = gr.File(
                    visible=True, label="Download Example Embeddings"
                )

        with gr.Tab("1. Upload File and Embed"):
            # read csv file and select columns
            gr.Markdown(
                value=" 📖 Upload a CSV file that contains the text to be measured and a column that contains the document ID. Alternatively, you can click the Load Example Dataset and Scales button to explore the tool with an included sample dataset (4000 Facebook posts) and pre-defined Competing Values Framework (CVF) dimension and scales.",
                label="",
            )
            with gr.Row(variant="panel"):
                input_file = gr.File(
                    file_count=1,
                    file_types=[".csv"],
                    label="Input CSV File",
                )

                # read_col_btn = gr.Button(value="Read CSV File Columns")
                doc_col_selector = gr.Dropdown(
                    choices="",
                    label="Select Text Column",
                    interactive=False,
                )
                doc_id_col_selector = gr.Dropdown(
                    choices="",
                    label="Select Document ID Column",
                    interactive=False,
                )
                doc_id_col_btn = gr.Button(value="Confirm Column Selections")
                doc_id_col_btn.click(
                    fn=m.set_doc_col,
                    inputs=[doc_col_selector, doc_id_col_selector],
                    outputs=doc_id_col_btn,
                )

            gr.Markdown(
                value=""" 📖 Select an embedding method. You can use Sentence Transformers to embed the text locally; in this case you can use the default model name (all-MiniLM-L6-v2) or [any other models](https://www.sbert.net/docs/pretrained_models.html).
                You can also use the [OpenAI API](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings) to embed (text-embedding-ada-002 model). You can get an API key [here](https://platform.openai.com/account/api-keys). """,
            )

            with gr.Row(variant="panel"):
                # embedding options
                embedding_model_dropdown = gr.Dropdown(
                    choices=[
                        "Sentence Transformers (Local)",
                        "OpenAI text-embedding-ada-002 (API Key Required)",
                    ],
                    label="Select an embedding model:",
                    interactive=True,
                    value="Sentence Transformers (Local)",
                    multiselect=False,
                )
                openai_api_key = gr.Textbox(
                    placeholder="We strongly recommed revoking your API key after public testing as our server does not provide any security guarantees to prevent potential security breaches.",
                    label="OpenAI API Key (leave blank if key is set as Environment Variable)",
                    value="",
                    visible=False,
                    interactive=True,
                )
                sbert_model_textbox = gr.Textbox(
                    label="Sentence Transformers Model Name",
                    value="all-MiniLM-L6-v2",
                    interactive=True,
                    visible=True,
                )
                set_emb_btn = gr.Button(value="Set Embedding Model", visible=True)
                set_emb_btn.click(
                    fn=m.set_embedding_model,
                    inputs=[
                        embedding_model_dropdown,
                        sbert_model_textbox,
                        openai_api_key,
                    ],
                    outputs=set_emb_btn,
                )
            gr.Markdown(
                value=" 📖 Click the Embed Documents button. \n Alternatively, you can upload a numpy array file (.npy or .npz) with precomputed document embeddings. The file should be generated using numpy.save() with the shape (n_docs, embedding_dim). It must be embedded using the same embedding model as selected above, because the queries will be embedded using the same model. "
            )
            with gr.Row(variant="panel"):
                upload_emb_choice = gr.Radio(
                    choices=["Embed Documents", "Upload Precomputed Embeddings"],
                    label="Embedding Options",
                    visible=True,
                    interactive=True,
                )

                embed_btn = gr.Button("Embed Documents", visible=False)
                input_embedding_uploader = gr.File(
                    file_count=1,
                    file_types=[".npz", ".npy"],
                    label="Precomputed Embeddings",
                    visible=False,
                )
                emb_result_txtbox = gr.Textbox(
                    value="",
                    label="Embedding Progress (check console for progress updates)",
                )
                emb_results_file = gr.File(visible=False)

            upload_emb_choice.change(
                fn=m.set_embedding_option,
                inputs=upload_emb_choice,
                outputs=[input_embedding_uploader, embed_btn],
            )

            input_file.change(
                fn=m.read_csv_cols,
                inputs=input_file,
                outputs=[doc_col_selector, doc_id_col_selector],
            )

            input_embedding_uploader.change(
                fn=m.read_input_embedding,
                inputs=input_embedding_uploader,
                outputs=[emb_result_txtbox, embed_btn],
            )
            embedding_model_dropdown.change(
                fn=m.toggle_embedding_model_visibility,
                inputs=[
                    embedding_model_dropdown,
                ],
                outputs=[sbert_model_textbox, openai_api_key, set_emb_btn],
            )
            sbert_model_textbox.change(
                fn=m.reset_set_emb_btn,
                outputs=[set_emb_btn],
            )
            # embedding
            openai_api_key.change(
                fn=m.set_openai_api_key,
                inputs=openai_api_key,
                outputs=[emb_result_txtbox],
            )
            embed_btn.click(
                fn=m.embed_df,
                outputs=[emb_result_txtbox, emb_results_file, emb_results_file],
            )
        with gr.Tab("2. Define Dimensions and Semantic Search"):
            gr.Markdown(
                value=" 📖 Move the sliders to set the number of dimensions and the number of results in each round of semantic search. ",
                label="",
            )
            n_dim_slider = gr.Slider(
                1,
                10,
                step=1,
                value=4,
                interactive=True,
                label="Number of dimensions",
            )
            n_results_slider = gr.Slider(
                10,
                200,
                step=5,
                value=10,
                interactive=True,
                label="Number of results in search",
            )

            gr.Markdown(
                value=" 📖 Enter the names of the dimensions and seed search queries. Then click the Search Dimension button to search for relevant documents in the corpus. Copy, add, and edit the relevant documents to query box to conduct next round of search. Each query should be in its own line. ",
                label="",
            )

            for i in range(10):
                with gr.Row(variant="panel") as a_row:
                    all_rows_dims.append(a_row)
                    if i < 4:
                        a_row.visible = True
                    else:
                        a_row.visible = False
                    # 3 columns: query box, search button, search results
                    with gr.Column(scale=4, min_width=400):
                        if i < 4:
                            # default 4 dimensions visible
                            all_dim_name_boxes.append(
                                gr.Textbox(
                                    lines=1,
                                    max_lines=1,
                                    interactive=True,
                                    placeholder="e.g. " + CVFDemo().dim_names[i],
                                    value="",
                                    label=f"Dimension {i + 1} Name (Optional)",
                                    visible=True,
                                ),
                            )
                            all_search_query_boxes.append(
                                gr.Textbox(
                                    lines=5,
                                    interactive=True,
                                    label=f"Query (Seed) Sentences for Dimension {i + 1}. One per line. (Required)",
                                    value="",
                                    placeholder="e.g. " + CVFDemo().dim_seeds[i],
                                    visible=True,
                                )
                            )
                        else:
                            all_dim_name_boxes.append(
                                gr.Textbox(
                                    lines=1,
                                    max_lines=1,
                                    placeholder=None,
                                    interactive=True,
                                    label=f"Dimension {i + 1} Name (Optional)",
                                    value="",
                                    visible=True,
                                ),
                            )
                            all_search_query_boxes.append(
                                gr.Textbox(
                                    lines=5,
                                    interactive=True,
                                    label=f"Query (Seed) Sentences for Dimension {i + 1}. One per line. (Required)",
                                    value="",
                                    placeholder=f"Dimension {i + 1} seed sentences, one per line",
                                    visible=True,
                                )
                            )

                    with gr.Column(scale=1):
                        # search button and result textbox
                        all_search_btns.append(
                            gr.Button(
                                f"Search Dimension {i + 1}",
                                visible=True,
                            )
                        )

                    with gr.Column(scale=4, min_width=400):
                        all_search_results.append(
                            gr.Textbox(
                                value="",
                                label=f"Search Results for Dimension {i + 1}. Copy and paste relevant sentences into the query box to the left.",
                                visible=True,
                            )
                        )

            n_dim_slider.change(
                fn=m.toggle_row_vis, inputs=n_dim_slider, outputs=all_rows_dims
            )
            for box in all_search_query_boxes:
                box.change(fn=m.toggle_row_vis, inputs=n_dim_slider)

            for dim_i, btn in enumerate(all_search_btns):
                btn.click(
                    fn=m.semantic_search,
                    inputs=[
                        all_search_query_boxes[dim_i],
                        n_results_slider,
                    ],
                    outputs=all_search_results[dim_i],
                )
            gr.Markdown(
                value=" 📖 After defining the dimensions and final queries, click the Embed Queries and Save Dimensions button to embed the queries and save the dimension definitions. Each dimension must contain at least one query.",
                label="",
            )
            save_dim_button = gr.Button("Embed Queries and Save Dimensions")
            dim_define_results = gr.Textbox(visible=False, label="")
            dimension_def_file_download = gr.File(visible=False)

        with gr.Tab("3. Define Scales"):
            gr.Markdown(
                value=" 📖 Move the sliders to set the number of scales.",
                label="",
            )
            n_scale_slider = gr.Slider(
                1,
                10,
                step=1,
                value=2,
                interactive=True,
                label="Number of scales",
            )

            gr.Markdown(
                value=" 📖 Enter the names of the scales and select the relevant dimensions. Then click the Save Scales button to compute the scale embedding vectors. Each scale is computed by first averaging the dimension embedding vectors in the positive and negative dimensions, and then taking the difference between the two vectors. Each scale must contain at least one positive dimension. ",
                label="",
            )
            for i in range(10):
                with gr.Row(variant="panel") as a_row:
                    all_rows_scale.append(a_row)
                    if i >= 2:
                        a_row.visible = False
                        with gr.Column(scale=2, min_width=400):
                            all_scale_name_boxes.append(
                                gr.Textbox(
                                    lines=1,
                                    max_lines=1,
                                    interactive=True,
                                    placeholder=None,
                                    value="",
                                    label=f"Scale {i + 1} Name",
                                    visible=True,
                                ),
                            )
                    else:
                        a_row.visible = True
                        with gr.Column(scale=2, min_width=400):
                            all_scale_name_boxes.append(
                                gr.Textbox(
                                    lines=1,
                                    max_lines=1,
                                    interactive=True,
                                    placeholder="e.g. "
                                    + list(CVFDemo().scales.keys())[i],
                                    value="",
                                    label=f"Scale {i + 1} Name",
                                    visible=True,
                                ),
                            )
                        # 3 columns: scale name, positive dimension, negative dimension
                    with gr.Column(scale=4):
                        all_scale_pos_selector.append(
                            gr.Dropdown(
                                interactive=True,
                                multiselect=True,
                                label=f"Positive Dimensions for Scale {i + 1} (Required)",
                                value=None,
                                visible=True,
                            )
                        )
                    with gr.Column(scale=4):
                        all_scale_neg_selector.append(
                            gr.Dropdown(
                                interactive=True,
                                multiselect=True,
                                label=f"Negative Dimensions for Scale {i + 1} (Optional)",
                                value=None,
                                visible=True,
                            )
                        )
            save_dim_button.click(
                fn=m.save_dims,
                inputs=all_dim_name_boxes + all_search_query_boxes,
                outputs=all_scale_pos_selector
                + all_scale_neg_selector
                + [dim_define_results]
                + [dimension_def_file_download] * 2,
            )
            n_scale_slider.change(
                fn=m.toggle_row_vis_scales,
                inputs=n_scale_slider,
                outputs=all_rows_scale,
            )
            for box in all_scale_name_boxes:
                box.change(fn=m.toggle_row_vis_scales, inputs=n_scale_slider)
            for box in all_scale_pos_selector:
                box.change(fn=m.toggle_row_vis_scales, inputs=n_scale_slider)
            for box in all_scale_neg_selector:
                box.change(fn=m.toggle_row_vis_scales, inputs=n_scale_slider)
            save_scale_button = gr.Button("Save Scales")
            scale_define_results = gr.Textbox(visible=False, label="")
            scale_def_file_download = gr.File(visible=False)

            save_scale_button.click(
                fn=m.save_scales,
                inputs=all_scale_name_boxes
                + all_scale_pos_selector
                + all_scale_neg_selector,
                outputs=[scale_define_results] + [scale_def_file_download] * 2,
            )
        with gr.Tab("4. Measurement"):
            gr.Markdown(
                value=" 📖 Click the Measure Documents Using Semantic Projection button to score each document. The output file will contain the document ID and the scores for each scale.",
                label="",
            )
            with gr.Row(variant="panel"):
                whitening_radio_btn = gr.Radio(
                    choices=["None", "ZCA"],
                    value="ZCA",
                    label="Whitening: used to decorrelate the scores of the semantic projection. We recommend using ZCA in most applications, especially if the scales are theoretically orthogonal.",
                    visible=True,
                )
            measure_button = gr.Button("Measure Documents Using Semantic Projection")
            measure_result = gr.Textbox(visible=False, label="")

            measure_results_file = gr.File(visible=False)
            measure_button.click(
                fn=m.measure_docs,
                inputs=whitening_radio_btn,
                outputs=[measure_result, measure_results_file, measure_results_file],
            )
        example_btn.click(
            fn=m.load_example_dataset,
            outputs=[
                example_row,
                example_file_download,
                example_emb_download,
            ]
            + all_dim_name_boxes
            + all_search_query_boxes
            + all_scale_name_boxes
            + all_scale_pos_selector
            + all_scale_neg_selector
            + [n_dim_slider, n_scale_slider],
        )
    demo.queue(concurrency_count=3)
    if mode == "public":
        demo.launch(
            share=True,
            auth=(username, password),
            server_name="0.0.0.0",
            favicon_path="sample_data/favicon.png",
        )
    elif mode == "intranet":
        demo.launch(
            server_name="0.0.0.0",
            auth=(username, password),
            favicon_path="sample_data/favicon.png",
        )
    elif mode == "local":
        demo.launch(favicon_path="sample_data/favicon.png")
    else:
        print("Invalid mode. Please choose from 'public', 'intranet', or 'local'.")


if __name__ == "__main__":
    # Script entry point: hand run_gui to python-fire, which builds the
    # command-line interface from it and invokes it with the parsed arguments.
    fire.Fire(component=run_gui)
