def get_js_tokenizer(model_identifier):
    """Create a tokenizer whose work is delegated to JavaScript functions.

    The `js` module and `pyodide.ffi` are imported on call, so this only
    works where those modules are importable. `model_identifier` must
    contain the substring "gpt"; otherwise the assertion below fails.

    Returns a fresh `JSBridgedTokenizer` instance.
    """
    import js
    from pyodide.ffi import to_js

    assert "gpt" in model_identifier, "JS tokenizer only supports GPT models."

    class JSBridgedTokenizer:
        """Browser-only tokenizer that forwards GPT tokenization to `js`."""

        def __init__(self):
            # the vocabulary is fetched from JS lazily, on first access
            self._vocab = None
            # bos and eos share the same token id
            self.bos_token_id = self.eos_token_id = 50256

        @property
        def vocab(self):
            """Vocabulary obtained from `js.get_vocab()`, cached after the first call."""
            cached = self._vocab
            if cached is None:
                cached = self._vocab = js.get_vocab().to_py()
            return cached

        @property
        def vocab_size(self):
            """Number of entries in `vocab`."""
            return len(self.vocab)

        def convert_tokens_to_string(self, tokens):
            """Pass `tokens` to `js.convert_tokens_to_string_gpt` and return its result."""
            js_tokens = to_js(tokens)
            return js.convert_tokens_to_string_gpt(js_tokens)

        def tokenize(self, s):
            """Tokenize `s` via `js.tokenize_gpt_toks`.

            A `list` input yields one result per element; any other input
            is tokenized as a single item and its result returned directly.
            """
            if type(s) is list:
                return [js.tokenize_gpt_toks(item).to_py() for item in s]
            return js.tokenize_gpt_toks(s).to_py()

        def decode(self, input_ids):
            """Turn `input_ids` back into a string via `js.detokenize_gpt`."""
            # coerce every id to a plain int before handing the list to JS
            plain_ids = [int(token_id) for token_id in input_ids]
            return str(js.detokenize_gpt(to_js(plain_ids)))

        def __call__(self, s: str):
            """Encode `s` via `js.tokenize_gpt` into `{"input_ids": ...}`.

            A `list` input produces a list of id lists; any other input
            produces a single id list.
            """
            if type(s) is list:
                batch = [[int(v) for v in js.tokenize_gpt(item)] for item in s]
                return {"input_ids": batch}
            return {"input_ids": [int(v) for v in js.tokenize_gpt(s)]}

    return JSBridgedTokenizer()

# Module-wide registry of LMQL special tokens, shared by all tokenizers in
# this module: identifier -> token id ...
special_token_mappings = dict()
# ... and the inverse lookup, token id -> identifier.
reverse_special_token_mappings = dict()

class LMQLTokenizer:
    """Wrap a tokenizer implementation and add handling of LMQL tags.

    Text of the form ``<lmql:NAME/>`` is not passed to the wrapped
    tokenizer. It is represented by the token string ``"lmql:NAME"`` (see
    `tokenize`) or by a dedicated integer id (see `__call__`).

    The identifier/id registry lives in the module-level dictionaries
    `special_token_mappings` and `reverse_special_token_mappings`, so it is
    shared by every instance of this class.
    """

    def __init__(self, tokenizer_impl):
        # the wrapped tokenizer; everything that is not a tag is delegated to it
        self.tokenizer_impl = tokenizer_impl

    @property
    def vocab_size(self):
        """Vocabulary size reported by the wrapped tokenizer."""
        return self.tokenizer_impl.vocab_size

    @property
    def bos_token_id(self):
        """Beginning-of-sequence id of the wrapped tokenizer."""
        return self.tokenizer_impl.bos_token_id

    @property
    def eos_token_id(self):
        """End-of-sequence id of the wrapped tokenizer."""
        return self.tokenizer_impl.eos_token_id

    @property
    def vocab(self):
        """Vocabulary of the wrapped tokenizer."""
        return self.tokenizer_impl.vocab

    def convert_tokens_to_string(self, tokens):
        """Delegate to the wrapped tokenizer's `convert_tokens_to_string`."""
        return self.tokenizer_impl.convert_tokens_to_string(tokens)

    def tokenize(self, s):
        """Split the string `s` into tokens.

        Each ``<lmql:NAME/>`` tag becomes the single token ``"lmql:NAME"``;
        the text between tags is tokenized by the wrapped tokenizer.
        """
        tokens = []
        # chunk_out_by_tags alternates text and tag segments, so tags always
        # sit at odd positions. Deciding by position rather than by an
        # "lmql:" prefix keeps ordinary text that merely starts with
        # "lmql:" from being mistaken for a tag.
        for index, chunk in enumerate(self.chunk_out_by_tags(s, tokenize=False)):
            if index % 2 == 1:
                tokens.append(chunk)
            else:
                tokens += self.tokenizer_impl.tokenize(chunk)
        return tokens

    def decode(self, input_ids):
        """Convert `input_ids` back into a string.

        Ids registered as special tokens are rendered as ``<lmql:NAME/>``;
        the runs of ids between them are decoded by the wrapped tokenizer.
        """
        return "".join(
            chunk if isinstance(chunk, str) else self.tokenizer_impl.decode(chunk)
            for chunk in self.chunk_out_by_special_ids(input_ids)
        )

    def __call__(self, s: str):
        """Encode `s` into ``{"input_ids": ...}``.

        `s` may be a single string or a list of strings. A list yields a
        list of id lists, anything else a single id list. Tags are mapped
        to their special token ids via `special_token_id`.
        """
        unpack = not isinstance(s, list)
        sequences = [s] if unpack else s

        input_ids = []
        for seq in sequences:
            chunk_input_ids = []
            for chunk in self.chunk_out_by_tags(seq):
                if isinstance(chunk, int):
                    # already a special token id
                    chunk_input_ids.append(chunk)
                else:
                    chunk_input_ids += self.tokenizer_impl(chunk)["input_ids"]
            input_ids.append(chunk_input_ids)

        if unpack:
            return {"input_ids": input_ids[0]}
        return {"input_ids": input_ids}

    def special_token_id(self, identifier):
        """Return the id registered for `identifier`, registering it if new."""
        global special_token_mappings
        global reverse_special_token_mappings

        if identifier not in special_token_mappings:
            if not special_token_mappings:
                # offset vocabulary IDs by at least the next decimal power of 10
                next_id = 10 ** (len(str(self.vocab_size)))
            else:
                next_id = max(special_token_mappings.values()) + 1
            special_token_mappings[identifier] = next_id
            reverse_special_token_mappings[next_id] = identifier
        return special_token_mappings[identifier]

    def chunk_out_by_special_ids(self, input_ids, tokenize=True):
        """Yield `input_ids` split at registered special token ids.

        Runs of ordinary ids are yielded as lists and each special id as the
        string ``"<IDENTIFIER/>"``. The trailing run is always yielded, even
        when it is empty. The `tokenize` argument is not used.
        """
        global reverse_special_token_mappings
        c = []
        for i in input_ids:
            if i in reverse_special_token_mappings:
                if len(c) > 0:
                    yield c
                c = []
                yield "<" + reverse_special_token_mappings[i] + "/>"
            else:
                c.append(i)
        yield c

    def chunk_out_by_tags(self, s, tokenize=True):
        """Split `s` into alternating text and tag segments.

        The result always has the shape ``[text, tag, text, ..., text]``
        (text segments may be empty). With ``tokenize=True`` a tag is given
        as its special token id, otherwise as the string ``"lmql:NAME"``.
        """
        # filter out all special tokens <lmql:.../>
        import re
        segments = []
        offset = 0
        for m in re.finditer(r"<lmql:(.*?)\/>", s):
            segments.append(s[offset:m.start()])
            if tokenize:
                segments.append(self.special_token_id("lmql:" + m.group(1)))
            else:
                segments.append("lmql:" + m.group(1))
            offset = m.end()
        segments.append(s[offset:])
        return segments

def load_tokenizer(model_identifier):
    """Return an `LMQLTokenizer` for `model_identifier`.

    When the LMQL_BROWSER environment variable is set, the JS-bridged
    tokenizer from `get_js_tokenizer` is wrapped. Otherwise a
    `transformers.AutoTokenizer` is used. A pickled copy is cached in
    ``~/.cache/lmql`` and preferred on later calls; if that copy cannot be
    unpickled, the tokenizer is loaded again and the cache rewritten.
    """
    import os

    # browser builds are detected via the LMQL_BROWSER environment variable
    if "LMQL_BROWSER" in os.environ:
        return LMQLTokenizer(get_js_tokenizer(model_identifier))

    import pathlib
    import pickle

    from transformers import AutoTokenizer

    cache_dir = pathlib.Path.home() / ".cache" / "lmql"
    cache_dir.mkdir(parents=True, exist_ok=True)
    cache_identifier = model_identifier.replace("/", "-")
    cache_path = cache_dir / f"tokenizer-{cache_identifier}.pkl"

    # first try to load pickled tokenizer from cache (faster)
    # NOTE(security): unpickling executes code from the file; this assumes
    # the cache directory is only written by this function.
    if cache_path.exists():
        try:
            with open(cache_path, "rb") as f:
                cached = pickle.load(f)
        except (pickle.UnpicklingError, EOFError, AttributeError,
                ImportError, IndexError):
            # truncated or stale cache file: rebuild it below instead of
            # failing on every subsequent call
            pass
        else:
            return LMQLTokenizer(cached)

    t = AutoTokenizer.from_pretrained(model_identifier)
    with open(cache_path, "wb") as f:
        pickle.dump(t, f)
    return LMQLTokenizer(t)

if __name__ == "__main__":
    # Command-line helper: <script> MODEL_IDENTIFIER INPUT
    # INPUT is either a JSON list of token ids (decoded to text) or a string
    # (encoded, followed by a listing of short vocabulary entries that
    # contain a digit).
    import sys

    model_identifier = sys.argv[1]
    t = load_tokenizer(model_identifier)

    to_tokenize = sys.argv[2]

    if to_tokenize.startswith("["):
        import json
        to_tokenize = json.loads(to_tokenize)
        # pass the plain id list: no tensor library is imported at module
        # level, and plain ints can be matched against special token ids
        print(repr(t.decode(to_tokenize)))
    else:
        res = t(to_tokenize)
        print(res)
        # LMQLTokenizer has no convert_ids_to_tokens; tokenize() gives the
        # token strings for the same input
        print(t.tokenize(to_tokenize))

        digits = "0123456789"
        # vocabulary entries shorter than 4 characters that contain a digit,
        # ordered by token id
        short_digit_tokens = [
            (token, token_id)
            for token, token_id in sorted(t.vocab.items(), key=lambda p: p[1])
            if len(token) < 4 and any(c in digits for c in token)
        ]
        for token, token_id in short_digit_tokens:
            print(token, token_id)
        print(len(short_digit_tokens))
        print("".join(f'"{token}",' for token, _ in short_digit_tokens))