Source code for pythainlp.augment.lm.wangchanberta

# -*- coding: utf-8 -*-
# Copyright (C) 2016-2023 PyThaiNLP Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import (
    CamembertTokenizer,
    pipeline,
)
import random
from typing import List

model_name = "airesearch/wangchanberta-base-att-spm-uncased"


[docs]class Thai2transformersAug:
[docs] def __init__(self): self.model_name = "airesearch/wangchanberta-base-att-spm-uncased" self.target_tokenizer = CamembertTokenizer self.tokenizer = CamembertTokenizer.from_pretrained( self.model_name, revision="main" ) self.tokenizer.additional_special_tokens = [ "<s>NOTUSED", "</s>NOTUSED", "<_>", ] self.fill_mask = pipeline( task="fill-mask", tokenizer=self.tokenizer, model=f"{self.model_name}", revision="main", ) self.MASK_TOKEN = self.tokenizer.mask_token
[docs] def generate(self, sentence: str, num_replace_tokens: int = 3): self.sent2 = [] self.input_text = sentence sent = [ i for i in self.tokenizer.tokenize(self.input_text) if i != "▁" ] if len(sent) < num_replace_tokens: num_replace_tokens = len(sent) masked_text = self.input_text for i in range(num_replace_tokens): replace_token = [ sent.pop(random.randrange(len(sent))) for _ in range(1) ][0] masked_text = masked_text + self.MASK_TOKEN self.sent2 += [ str(j["sequence"]).replace("<s> ", "").replace("</s>", "") for j in self.fill_mask(masked_text) if j["sequence"] not in self.sent2 ] masked_text = self.input_text return self.sent2
[docs] def augment(self, sentence: str, num_replace_tokens: int = 3) -> List[str]: """ Text Augment from wangchanberta :param str sentence: thai sentence :param int num_replace_tokens: number replace tokens :return: list of text augment :rtype: List[str] :Example: :: from pythainlp.augment.lm import Thai2transformersAug aug=Thai2transformersAug() aug.augment("ช้างมีทั้งหมด 50 ตัว บน") # output: ['ช้างมีทั้งหมด 50 ตัว บนโลกใบนี้', 'ช้างมีทั้งหมด 50 ตัว บนสุด', 'ช้างมีทั้งหมด 50 ตัว บนบก', 'ช้างมีทั้งหมด 50 ตัว บนนั้น', 'ช้างมีทั้งหมด 50 ตัว บนหัว'] """ self.sent2 = [] self.sent2 = self.generate(sentence, num_replace_tokens) return self.sent2