うしのおちちの備忘録

AtCoderや日記、自然言語処理などについて書きます。

Rust✖️WebAssemblyでブラウザ上で固有表現抽出する

はじめに

最近sudachi.rsをwasmにビルドして使ってるツイートを見て、自然言語処理の結構いろんなことが同じようにできるのではと思い、手始めに固有表現抽出してみることにしました。

https://twitter.com/vbkaisetsu/status/1412761328943460355?s=21

使用ライブラリ

形態素解析

sudachi.rsdを使います。ただ、2021/9/15日現在のものではビルドは通りますが初期化時にどうしてもパニックしてしまうため、適当に動いてそうなコミットまで戻って使いました。

固有表現抽出

今回はRust✖️WebAssemblyの入門ということで、簡単にCRFを使うことにします。 Rustではpure-rustで書かれたCRFクレートが公開されています。ただ、学習には対応していないようなので、学習だけはcrfsuiteのRustバインディングを使う必要があります。とはいっても、メンテナは同じみたいで、実際使い勝手はほとんど変わらないので特に問題はなさそうです。

モデルの学習

先ほど紹介したcrfsuite-rsを使って、固有表現抽出用のモデルを学習していきます。 今回はたまたま手元に病名抽出の学習データがあるので、それを使っていきます。

データの準備

とりあえずIOB2タグ形式の学習データを作っていきます。別にxmlみたいな生テキストから形態素解析するところまで学習コードでやってもいいですが、IOB2タグ形式のデータが手元にあってのでそれベースでやっていきます。

IOB2タグ形式のファイル例です。1行1トークンで(トークン、ラベル)列を並べて行って、空行で文を区切ります。今回は文字種や品詞を特徴量として使わないのでそれらの情報は入れていません。

今治 O
タオル O
は O
愛媛 B-Location
の O
誇り O
だ O

みかん O
は O
おいしい O

ちなみにbratなどを使っていてデータ形式がオフセットになっている場合、Spacyのiob_utilsが便利です。生テキストとオフセットを入れると、タグが返ってきます。現状BILUO形式しか対応していませんが、BILUOからBIOへの変換の関数も用意されているのでそれを使えばオッケーです。

データの前処理

crfsuite-rsは、1文ごとに素性列Vec<Vec<Attribute>>とラベル列Vec<String>を用意して学習に使います。Attributenamevalueを持つ構造体で、nameに素性を、valueに素性の値を入れます。valueは特に素性に重み付けしたいわけでなければ全部1で良さそうです。

とりあえずIOB2タグ形式のデータを読み込みます。Rusut初心者なので書き方が拙いのは勘弁してください...

use std::fs::File;
use std::io::{self, BufRead, BufReader};

fn load_dataset(path: &str) -> (Vec<Vec<Vec<Attribute>>>, Vec<Vec<String>>){
    let file = File::open(path).unwrap();
    let reader = BufReader::new(file);
    let mut x_sent: Vec<String> = Vec::new();
    let mut y_sent: Vec<String> = Vec::new();

    let mut x_all: Vec<Vec<Vec<Attribute>>> = Vec::new();
    let mut y_all: Vec<Vec<String>> = Vec::new();

    for line in reader.lines() {
        let l = line.unwrap();
        if l == "" {
            if x_sent.len() != 0 {
                let attributes = extract_features(&x_sent);
                x_all.push(attributes);
                y_all.push(y_sent);

                x_sent = Vec::new();
                y_sent = Vec::new();
            }
            continue;
        }

        let mut iter = l.split_whitespace();
        let x = match iter.next() {
            Some(x) => x.to_string(),
            None => {
                "None".to_string()
            }
        };
        let y = match iter.next() {
            Some(x) => x.to_string(),
            None => {
                "None".to_string()
            }
        };
        x_sent.push(x);
        y_sent.push(y);
        assert_eq!(iter.next(), None);
    }
    (x_all, y_all)
}

読み込んだデータを素性に変換していきます。今回は精度を追い求めているわけではないので、前後一文字とBi-gramだけを素性に使います。

use crfsuite::Attribute;
fn extract_features(sent: &Vec<String>) -> Vec<Vec<Attribute>> {
    let mut result: Vec<Vec<Attribute>> = Vec::new();
    for idx in 0..sent.len() {
        let mut attributes: Vec<Attribute> = Vec::new();
        attributes.push(Attribute::new(&sent[idx], 1.0));

        let pre_word = match idx {
            0 => "BOS",
            _ => &sent[idx-1],
        };

        let post_word = match idx {
            n if n >= sent.len() - 1 => "EOS",
            _ => &sent[idx+1],
        };

        attributes.push(Attribute::new(format!("-1_{}", pre_word), 1.0));
        attributes.push(Attribute::new(format!("{}_{}", pre_word, &sent[idx]), 1.0));

        attributes.push(Attribute::new(format!("+1_{}", post_word), 1.0));
        attributes.push(Attribute::new(format!("{}_{}", &sent[idx], post_word), 1.0));

        result.push(attributes);
    }
    return result;
}

モデルの学習

先ほど読み込んだデータを元にモデルを学習します。モデルの学習自体は特に難しいことはなさそうですね。ドキュメントは充実してないですが、適当にテストコードみてパクりました。 学習はCPUでも十分におわります。学習データは50万文程度でしたが、半日放置すれば学習は終わっていました(放置していた&時間測ってないせいで正確な学習時間はわかりませんが...)。

use std::path::Path;
use crfsuite::{Trainer, Attribute, Algorithm, GraphicalModel};

fn main() {
    let train_path = Path::new("./resource/train.iob");
    let valid_path = Path::new("./resource/valid.iob");
    let (x_all, y_all) = load_dataset(train_path.to_str().unwrap());

    let mut trainer = Trainer::new(true);
    trainer
        .select(Algorithm::LBFGS, GraphicalModel::CRF1D)
        .unwrap();

    for (x_sent, y_sent) in x_all.iter().zip(y_all.iter()) {
        trainer.append(x_sent, y_sent, 0i32).unwrap();
    }
    trainer.train("models/hatena/crfsuite", -1i32).unwrap();
}

推論

生テキストを読み込んで、学習させたCRFに推論させていきます。推論には冒頭で触れたCRFクレートであるcrfs-rsを使います。使い方は学習時に使ったものと同じです。ラベルの遷移に制限をつけたかったですが、そういう機能は見つけられませんでした。例えばOからIには絶対に遷移しないので、遷移行列に罰則を加えたりしたかったですよね。

ともあれ、一緒にwasm-bindingjs-sysを使ってWebAssembly用にコードを書きます。基本的には#[wasm_bindgen]をつけた関数などがWebAssemblyで使えるようにコンパイルされます。

include_bytes!でモデルと辞書を埋め込んで、WebAssemblyで呼び出せるように書いていきます。

use crfs::Model;
use wasm_bindgen::prelude::*;

#[wasm_bindgen]
pub struct NER {
    tagger: Model<'static>,
    tokenizer: Tokenizer<'static>,
}

#[wasm_bindgen]
impl NER {
    pub fn new() -> Self {
        // data for crf
        let buf = include_bytes!("/path/to/models/hatena.crfsuite");

        // data for tokenizer
        let bytes = include_bytes!("/path/to/sudachi.rs/src/resources/system.dic");
        Self {
            tagger: Model::new(buf).unwrap(),
            tokenizer: Tokenizer::new(bytes),
        }
    }
}

学習時にはすでにトークナイズ済みのデータを使っていましたが、推論では生テキストを扱うので、形態素解析のための関数も必要です。トークナイズはwasmから呼び出す必要はないので#[wasm_binding]はつけません。

impl NER {
    pub fn tokenize(&self, s: &str) -> Vec<String> {
        let morpheme_list = self.tokenizer.tokenize(&s.to_string(), &Mode::B, false);

        let tokens: Vec<String> = morpheme_list
            .iter()
            .map(|m| String::from(m.surface()))
            .collect();
        return tokens;
    }
}

いよいよ、生テキストを読み込んでIOB2タグ系列を返す関数を作ります。IOB2タグだけ返しても仕方ないので、トークン列も一緒に返すことにします。javascriptで使うのでjs-sysというクレートでjavascript対応の変数に変換しています。

use js_sys::{Array, JsString, global, Object};

#[wasm_bindgen]
impl NER {
    pub fn tag(&mut self, sent: &str) -> Array {
        let xseq = self.tokenize(sent);

        let mut tagger = self.tagger.tagger().unwrap();
        let attributes = extract_features(&xseq);
        let res = tagger.tag(&attributes).unwrap();
        let yseq: Array = res.iter()
            .map(|s| JsString::from(s.to_string()))
            .collect();
        let xseq: Array = xseq
            .iter()
            .map(|s| JsString::from(s.clone()))
            .collect();
        Array::of2(&xseq, &yseq)
    }

WebAssembly

Cargo.toml

Cargo.tomlはこんな感じです。別に変わったことはありませんが、crate-type = ["cdylib"]にすることだけ注意です。

[package]
name = "ner-wasm"
version = "0.1.0"
authors = ["ujiuji1259 <suzzz428@gmail.com>"]
edition = "2018"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
crate-type = ["cdylib"]

[dependencies]
js-sys = "0.3.52"
wasm-bindgen = "0.2.74"
crfs = "0.1.3"

[dependencies.sudachi]
path = "/path/to/sudachi.rs"

ビルド

wasm-packといういい感じにjavascriptから使えるwasmにビルドしてくれる便利なツールがあるので、それを使います。インストールはcargo install wasm-packするだけです。

--target webにしておくと、javascriptから直接触れる感じのコードにしてくれるらしいです。web-packとか必要ない人はこちらで十分ぽいですが、正直よくわかってないです。

# install
cargo install wasm-pack

# build
wasm-pack build --target web

Javascriptから呼び出す

ビルドしたらpkgというディレクトリにいろいろ生成されますが、{パッケージ名}.jsというモジュールからビルドしたWebAssemblyを全部呼び出せます。

<script type="module">
    import init, {NER} from './pkg/ner_wasm.js';

    function ner() {
        let wasm = await init();
        model = NER.new();
        let tokens = model.tag("急性骨髄性白血病とは、血液のがんの一種です。");

        console.log(tokens[0]);
        console.log(tokens[1]);
    }
</script>

結果

がんはとれてないですが、急性骨髄性白血病はちゃんととれてるっぽいです。

f:id:kuroneko1259:20210816011121p:plain
出力結果

おわりに

RustでCRFによる固有表現抽出モデルを学習して、wasmとしてブラウザ上で動かしてみました。 固有表現抽出でも例に漏れずBERTなどが猛威を奮っていますが、シビアに精度が求められない場合はCRFでも十分固有表現抽出できたりします。CPUでも十分学習できますし、そんなにモデルサイズも大きくないです(素性の取り方によると思いますが、今回のモデルが67MBくらいです)。ブラウザで動作することを考えれば十分感がありませんか?

何に使えるかはわかりませんが、同じようなことがBERTとかでもできると思うので、いろいろ試していこうと思います。