Flairでサクッと文章分類を試してみる【RoBERTa】

Flairについて

FlairはPythonで書かれた自然言語処理の機械学習フレームワークです。

非常に簡潔にコードが書けるので、機械学習モデルをサクッと試すには良い選択肢だと思います。

GitHub - flairNLP/flair: A very simple framework for state-of-the-art Natural Language Processing (NLP)
A very simple framework for state-of-the-art Natural Language Processing (NLP) - GitHub - flairNLP/flair: A very simple framework for state-of-the-art Natural L...

この記事では、Flairを使用してlivedoorのニュース記事を分類するモデルを構築していきます。

データセットについて

今回はデータとして以下のlivedoorのニュース記事を使用します。

ダウンロード - 株式会社ロンウイット
DOWNLOADS

このデータは株式会社ロンウイットさんが、収集・配布してくださっているデータです。

ありがたく使用させていただきます。

データ整形について

上記URLに配置されているldcc-20140209.tar.gzというファイルを今回は使用します。

wget https://www.rondhuit.com/download/ldcc-20140209.tar.gz
tar zxvf ldcc-20140209.tar.gz

圧縮ファイルを展開すると、ディレクトリ構成が以下のようになっています。

text/
├── dokujo-tsushin
├── it-life-hack
├── kaden-channel
├── livedoor-homme
├── movie-enter
├── peachy
├── smax
├── sports-watch
└── topic-news

dokujo-tsushinのフォルダの中を見てみるとdokujo-tsushin-4778030.txt、dokujo-tsushin-4778031.txtという風に、1記事あたり1ファイルに分かれて配置されています。

今回はFlairを使って学習をさせていくのですが、Flairは文章分類を行う際にfastTextで扱うのと同様の形式のファイルを入力に使用することができます。

すなわち、__label__クラス名\t文章という形式で整理していけばOKです。

pythonのスクリプトでデータを整形していきます。

import glob
import os
from collections import defaultdict
from tqdm import tqdm
import random
random.seed(777)

# 全ファイルのリストを取得
file_list = glob.glob("./text/**/*.txt")

contents_dict = defaultdict(list)
for f in tqdm(file_list):
    file_name = os.path.basename(f)
    # 記事ファイル以外のものをスキップする。
    if file_name in ["LICENSE.txt","CHANGES.txt","README.txt"]:
        continue

    # ディレクトリ名をラベルとして扱う
    label = f.split("/")[-2]

    txt = open(f,"r",encoding="utf8")
    content = ""
    for i,line in enumerate(txt):
        # 最初の2行は不要な情報
        if i < 2:
            continue
        if line.strip() == "":
            continue

        # 全角空白やタブなどを排除
        line = ''.join(line.split())

        content += f"{line.strip()} "
    
    content = content.strip()
    # ラベルごとに文章を格納する。
    contents_dict[label].append(content)

# train,test,devのファイルをそれぞれ作成する。
train_output_file = open("train_livedoor.tsv","w",encoding="utf8")
test_output_file = open("test_livedoor.tsv","w",encoding="utf8")
dev_output_file = open("dev_livedoor.tsv","w",encoding="utf8")

for label,content_list in contents_dict.items():
    for content in content_list:
        # ランダムに凡そtrain:test:dev = 8:1:1 となるように振り分ける。
        randam_int = random.randint(0,9)
        if randam_int < 8:
            train_output_file.write(f"__label__{label}\t{content}\n")
        if randam_int == 8:
            test_output_file.write(f"__label__{label}\t{content}\n")
        if randam_int == 9:
            dev_output_file.write(f"__label__{label}\t{content}\n")

上記のスクリプトの実行でtrain_livedoor.tsv,test_livedoor.tsv,dev_livedoor.tsvの3つのファイルができました。

これらのファイルを使って学習を行っていきます。

学習

Google Colaboratoryにて学習を行っていきます。

Google Driveの任意の場所に先ほど作成した3ファイルを配置しておきましょう。

まずは、flairをインストールします。

conlluというモジュールの開発が進み、依存関係で一部不整合があったので少し古いバージョンを入れています。

!pip install flair
!pip install conllu==4.4.2

今回使うモジュールをインポートします。

from flair.datasets.document_classification import ClassificationCorpus
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer
from pathlib import Path
import torch

FlairのClassificationCorpusを使って、今回使用するコーパスをロードします。

fastTextのデータ形式で入力ファイルを作成した場合はlabel_type=’topic’が使用できます。

corpus = ClassificationCorpus(Path('/content/drive/MyDrive/blog/flair_text_classification'),
                              label_type='topic',
                              test_file='test_livedoor.tsv',
                              dev_file='dev_livedoor.tsv',
                              train_file='train_livedoor.tsv',
                              )

続いては使用する文章ベクトルを定義していきます。

今回はFlair経由で「nlp-waseda/roberta-base-japanese」を使用させていただきます。

nlp-waseda/roberta-base-japanese

RoBERTaの標準的な入力の大きさは512tokenとなりますが、livedoorの文章の長さは512tokenでは収まり切りません。

1文章あたりを512tokenを限度としてその後ろを切り捨ててしまっても良いのですが、今回は全文使用することにします。

そのため、allow_long_sequences=Trueを指定し、cls_pooling=’mean’を指定しています。

cls_poolingは’cls’,’mean’,’max’の3種類を指定できますが、allow_long_sequences=Trueを指定する場合は’mean’か’max’のどちらかを指定することになります。

これはbert系のモデルでは[cls]という特殊トークンの分散表現を文章ベクトルと見立てることができますが、モデルの許容できる長さを超えた文章では使用できないからです。

※例えば文章全体が1000tokenだった場合には、bertの推論は2回行われて[cls]の分散表現は2つできてしまう。

一方で’mean’,’max’の場合は、token毎の分散表現を平均・最大でまとめてしまうため、長文に対しても使用することができます。

その他、メモリと学習時間節約のためfine_tune=Falseを指定しています。

embedding = TransformerDocumentEmbeddings('nlp-waseda/roberta-base-japanese',
                                          cls_pooling='mean',
                                          allow_long_sequences=True,
                                          fine_tune=False,
                                          model_max_length=512
                                          )

続いてはFlairのTextClassifierクラスのインスタンスを作成します。

classifier = TextClassifier(embedding,
                            label_dictionary=corpus.make_label_dictionary(label_type='topic'),
                            multi_label=False,
                            label_type='topic')

最後にTrainerを定義して、学習を回すだけです。

trainer = ModelTrainer(classifier,corpus)
trainer.train('./', 
              max_epochs=5,
              learning_rate=0.01,
              mini_batch_size=8,
              mini_batch_chunk_size=1,
              train_with_dev=False,
              monitor_test=False,
              optimizer=torch.optim.AdamW
              )
Results:
- F-score (micro) 0.7986
- F-score (macro) 0.7818
- Accuracy 0.7986

By class:
                precision    recall  f1-score   support

          smax     0.7615    0.9651    0.8513        86
  sports-watch     0.8218    0.9540    0.8830        87
dokujo-tsushin     0.8636    0.8539    0.8588        89
 kaden-channel     0.6354    0.7722    0.6971        79
   movie-enter     0.8864    0.8966    0.8914        87
        peachy     0.7528    0.7976    0.7746        84
  it-life-hack     0.9000    0.5056    0.6475        89
    topic-news     0.8615    0.8000    0.8296        70
livedoor-homme     0.7586    0.5000    0.6027        44

      accuracy                         0.7986       715
     macro avg     0.8046    0.7828    0.7818       715
  weighted avg     0.8086    0.7986    0.7921       715

Flairは最後に上記のようにprecision,recall,f1-scoreを綺麗にまとめて出力してくれます。

非常に簡潔なコードを書くだけで約80%くらいの精度のモデルを作成することができました。

Flairは少々癖はあるかと思いますが、上記のように簡単な記述でサクッと学習を試すには使い勝手が良いフレームワークだと思います。

文章分類だけではなく、固有表現抽出や関係性抽出などの情報抽出系のタスクにも使えるので、いずれは記事にしようと思います。

コメント

タイトルとURLをコピーしました