Metadata-Version: 2.1
Name: g-zhpr
Version: 0.1.3
Summary: 
Author: Philip Huang
Author-email: p208p2002@gmail.com
Requires-Python: >=3.8,<4.0
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Requires-Dist: openpyxl (>=3.0.10,<4.0.0)
Requires-Dist: pandas (>=1.5.1,<2.0.0)
Requires-Dist: scikit-learn (>=1.1.3,<2.0.0)
Requires-Dist: seqeval (>=1.2.2,<2.0.0)
Requires-Dist: transformers (>=4.24.0,<5.0.0)
Description-Content-Type: text/markdown

# Chinese Punctuation Restoration
Training dataset: [p208p2002/ZH-Wiki-Punctuation-Restore-Dataset](https://github.com/p208p2002/ZH-Wiki-Punctuation-Restore-Dataset)

Six punctuation marks are supported: ， 、 。 ？ ！ ；
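
Because the model restores punctuation, its input is normally unpunctuated text. The following is a minimal sketch, assuming you want to strip the six supported marks from already punctuated text (for example, to compare the model's output against the original). The helper `strip_punctuation` is hypothetical and not part of zhpr.

```python
# The six full-width punctuation marks listed above.
SUPPORTED_PUNCTUATION = "，、。？！；"

# Translation table that deletes every supported mark.
_DELETE_TABLE = str.maketrans("", "", SUPPORTED_PUNCTUATION)


def strip_punctuation(text):
    """Remove the six supported punctuation marks from `text`."""
    return text.translate(_DELETE_TABLE)


if __name__ == "__main__":
    print(strip_punctuation("特點是自由內容、自由編輯與自由著作權。"))
    # prints: 特點是自由內容自由編輯與自由著作權
```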

## Installation
```bash
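# torch is imported by the usage example but is not a declared dependency; install it first if needed:
# pip install torch pytorch-lightning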
pip install zhpr
```
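
The usage example imports `torch`, which the package metadata does not list among its requirements. The following is a minimal sketch for checking that the needed modules are importable before running it. The helper name `missing_modules` is hypothetical and not part of zhpr.

```python
import importlib.util


def missing_modules(names):
    """Return the top-level module names from `names` that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # Top-level modules imported by the usage example.
    missing = missing_modules(["torch", "transformers", "zhpr"])
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All required modules are available.")
```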

## Usage
```python
from zhpr.predict import DocumentDataset,merge_stride,decode_pred
from transformers import AutoModelForTokenClassification,AutoTokenizer
from torch.utils.data import DataLoader

def predict_step(batch,model,tokenizer):
    """Run the model on one batch and return, per window, a list of (token, label) pairs."""
    batch_out = []
    batch_input_ids = batch

    encodings = {'input_ids': batch_input_ids}
    output = model(**encodings)

    predicted_token_class_id_batch = output['logits'].argmax(-1)
    for predicted_token_class_ids, input_ids in zip(predicted_token_class_id_batch, batch_input_ids):
        tokens = tokenizer.convert_ids_to_tokens(input_ids)

        # Find where padding starts in input_ids so that the tokens
        # and the predictions can be truncated to the unpadded length.
        input_ids = input_ids.tolist()
        try:
            input_id_pad_start = input_ids.index(tokenizer.pad_token_id)
        except ValueError:
            # No pad token in this window: keep the full length.
            input_id_pad_start = len(input_ids)
        tokens = tokens[:input_id_pad_start]

        # Map predicted class ids to label names, then drop the padded positions.
        predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids]
        predicted_tokens_classes = predicted_tokens_classes[:input_id_pad_start]

        batch_out.append(list(zip(tokens,predicted_tokens_classes)))
    return batch_out

if __name__ == "__main__":
    window_size = 256
    step = 200
    text = "維基百科是維基媒體基金會運營的一個多語言的百科全書目前是全球網路上最大且最受大眾歡迎的參考工具書名列全球二十大最受歡迎的網站特點是自由內容自由編輯與自由著作權"
    dataset = DocumentDataset(text,window_size=window_size,step=step)
    dataloader = DataLoader(dataset=dataset,shuffle=False,batch_size=5)

    model_name = 'p208p2002/zh-wiki-punctuation-restore'
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    
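    # Register the uppercase letters A-Z and the space character as additional
    # special tokens, then resize the embedding matrix to the new vocabulary size.
    new_tokens = [chr(x) for x in range(ord('A'), ord('Z') + 1)]
    new_tokens.append(" ")
    tokenizer.add_special_tokens({'additional_special_tokens':new_tokens})
    model.resize_token_embeddings(len(tokenizer))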
    
    model_pred_out = []
    for batch in dataloader:
        batch_out = predict_step(batch,model,tokenizer)
        for out in batch_out:
            model_pred_out.append(out)
        
    # Merge the overlapping windows, then decode the (token, label) pairs into text.
    merge_pred_result = merge_stride(model_pred_out,step)
    merge_pred_result_decode = decode_pred(merge_pred_result)
    merge_pred_result_decode = ''.join(merge_pred_result_decode)
    print(merge_pred_result_decode)
```

Output:
```
維基百科是維基媒體基金會運營的一個多語言的百科全書，目前是全球網路上最大且最受大眾歡迎的參考工具書，名列全球二十大最受歡迎的網站，特點是自由內容、自由編輯與自由著作權。
```
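
The `window_size` and `step` arguments suggest that long documents are split into overlapping windows whose predictions are merged afterwards. The sketch below illustrates that general idea on a plain list. It is an assumption about the technique, not zhpr's actual `DocumentDataset` or `merge_stride` implementation, and the helper names `split_windows` and `merge_windows` are hypothetical.

```python
def split_windows(items, window_size, step):
    """Split `items` into overlapping windows that start every `step` items."""
    if not 0 < step <= window_size:
        raise ValueError("step must be in the range 1..window_size")
    windows = []
    start = 0
    while True:
        windows.append(items[start:start + window_size])
        if start + window_size >= len(items):
            break  # this window already reaches the end of the sequence
        start += step
    return windows


def merge_windows(windows, step):
    """Rebuild the sequence: first `step` items of each window, plus the whole last window."""
    merged = []
    for window in windows[:-1]:
        merged.extend(window[:step])
    merged.extend(windows[-1])
    return merged


if __name__ == "__main__":
    data = list(range(10))
    windows = split_windows(data, window_size=4, step=3)
    print(windows)
    # Merging the overlapping windows recovers the original sequence.
    assert merge_windows(windows, step=3) == data
```

Under this scheme, `window_size = 256` and `step = 200` would make consecutive windows overlap by 56 positions.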
