はじめまして!機械学習エンジニアリングチームで長期インターンをしているhikaru.tです。
この記事では、長期インターンにて取り組んだタスクについて、紹介させていただきたいと思います!
取り組んだタスク
pixivには、アップロードした画像を見て、おすすめタグを表示してくれるタグ推薦機能がありますが、このタグ推薦にはまだ改善の余地が多くあります。
それは、閾値を超えるタグがなかった場合、推薦を行わないため、まずタグ自体が出てこないということです。また、タグが出てきたとしても、あまり精度の良い推薦ができていないのが現状です。
そこで、今回のインターンで割り振られたのが「タグ推薦モデルの作成」というタスクでした。
インターン目標
現状のモデルには、以下のような課題があります。
- 使用モデルのアーキテクチャが開発当初のものになっている
- 再学習の余地が大きい
- おすすめが表示には閾値を超える必要がある
そこで、これらの課題を解決するようなモデルを構築することが、今回のインターンの目標でした。
タグ推定モデルの構築
まず、マルチラベル分類を用いた推定を用いました。
pixivのタグデータセットの形を見ると、これが最もシンプルな推定方法になります。 シンプルイズベストって言いますよね。 後述するように、色々試してみましたが、これが一番性能良かったです。
データ取得
- ラベルはpixivのタグデータから、”XXXXUsers入り”, “XXXX生誕祭”などの投稿時には付くはずがなかったり、特定の年号に依存してしまうタグを除いて、上位n件を取得してくる。
- それらのタグがついている画像のうち、ブックマーク数が一定以上のものを取得してくる。
タグの数が不均衡である問題の対策
初めに3行でまとめると、以下のような結果になりました。
- 人気タグとそうでないタグを比べると1000倍近くの件数差がある。
- 様々な対策を行ったが、効果のあるものは少なかった。
- ついている画像が少ないタグを省いたところ、Recallが改善した1。
pixivのタグは、冪乗則におおむね従って件数が減少していくロングテール型の分布となっています。
そのため、インバランスドなタグを平等に推薦する仕組みが大切になってきます。
そこで、各クラス間の件数が大きく異なっているデータセットに対応するためのLossやネットワーク構造をいくつか試しましたが、いろいろな理由で採用できませんでした。
1. Focal Loss
インバランスドなデータセットに対応する方法として最もポピュラーな手法
$0, 1$ のどちらかに振り分けられないデータ(≒少数データ)を重視するLossとして知られている。
今回のデータセットでは、学習を行なった結果、全ての結果が$0$になってしまうという問題が発生した。
理由としては、pixivのタグデータセットは、ついているべきタグがついていないことにより、確信度が低い状態になることが多く2、全てを$0$に近づけることがLossを最小化する上での最適解になってしまったと推察できる。
2. Influence-balanced Loss
決定境界をなめらかにすることによって、多数クラスの精度を犠牲にして、少数クラスの精度を上げる手法
元々Multi Class Classification用のLossだったものを、Multi Label用に雑に拡張して使ったが、抽出された特徴量を使った正則化がうまく働かなかった。きちんと拡張すれば動くかもしれないが、そこの理論を詰めるよりも、他のアプローチを試した方がいいという判断をした。
3. XML-CNN
低次元の表現に圧縮してから、高次元のタグ表現空間にマッピングするもので、特徴量の抽出を容易にする・非常に多くのタグに対して対応させることができるようにするための手法
少数データを捉えづらくなるという欠点はあるものの、低次元の表現空間にマッピングすることで、より簡単に全体的な特徴を捉えやすくなる。
ある程度の効果は見られたものの、学習・推論コストの増大に見合う効果は得られなかったため、適用しないことにした。
モデル構築
- バックボーンネットワークにはEfficient Net b0を利用した。
- 最終層の活性化関数をSigmoidにする一般的な多ラベル分類のネットワーク構造である。
学習セッティング
項目 | パラメータ |
---|---|
Dataset size | 418969枚 |
Train size: Validation_size: Test_size | 0.81:0.09:0.1 |
Batch size | 64 |
Early stopping | Stop when the recall is not updated for 3 epochs |
Loss function | Binary Cross Entropy |
Optimizer | SGD |
Initial learning rate | 5e-2 |
Weight_decay | 1e-4 |
Scheduler | ReduceLROnPlateau |
Factor | 0.2 |
Patience | 2 |
学習結果
全体としては40 epochくらいで収束して、Recallは30.7%となりました。
平均で3.9個のタグがついているデータセットなので、均すと各イラストについて1個以上は正解できている計算になるので、そこまで悪くない結果だと思います。
また、つくべきタグがついていない問題については、Test Datasetで評価しきれないので、それも数値を見る上で頭に入れておいた方がいいかもしれません。
その他試した手法
キャラクタ名を予測してから、そのキャラクタに付きがちなタグを共起情報から推定して、推薦するアルゴリズムや、Zero Shot Learningを利用したアプローチなどを試しました。
その結果、キャラクタ名を予測するアプローチについては一定の精度を得ることができたのですが、キャラクタが予測対象にない時の挙動が不安定なことや、pixivにアップロードされる画像にはオリジナルが多いなどの問題がありました。
タグベースの推論では、オリジナルキャラクタに対しても、汎用的なタグ(女の子・猫耳など)をつけることができることもあり、検討の結果、タグベースの推論を行う方が良いという結論になりました。
実画像に対する結果
「Re:ゼロから始まる異世界生活」から、サンタコスのレムさんのイラストを自分で描いて、どんな結果が出力されたかを確かめました。
結論から言うと以下のような結果になりました。
タグベース推薦
予測結果は上の画像のようになりました。 適切なタグは6/10ですが、同じようなタグが複数出力されていますね。
また、ラム・レムは同時に描かれることが多いせいか、どちらがレムでどちらがラムか区別がついていないような印象を受けます。 やはり表記揺れが出てしまって、実効的に予測できているタグが少なくなってしまっているのが残念です。
タグ予測Web APIの構築
次に、この予測モデルをWebサーバーから叩けるようにするため、APIの構築を行いました。
技術選定
Pythonのwebフレームワークには、
- Django
- Flask
- FastAPI
などがあります。
メンターさんに勧めていただいたこと&調べてみたらいい感じだったことから、今回はFastAPIを採用することにしました。
良いと思った点を、より細かく分けて説明すると、
- 今回Front Endは開発しないので、Swagger等を使ったデバッグをすることになる。その時、SwaggerとReDocを自動生成してくれるFastAPIは便利
- ステートレスなHTTPエンドポイントを作ることができればいいので、Djangoはリッチすぎる
- Flaskを触ったことがあったので、Flaskに近い書き方ができると嬉しい
- エンドポイントの定義が、型ヒントを使って直感的にできる
といった点が挙げられます。
モデルの事前読み込み
今回使用したモデルは数十MBしかないので、そこまで読み込みに時間はかかりません。しかし、毎回読み込んでいては、レスポンスタイムに影響が出ますし、非効率的です。
そこで、APIの初回呼び出し時のみモデルを読み込むようにして、レスポンスタイムを高速化します。
そのために、FastAPIオブジェクトを利用します。FastAPIオブジェクトは、Stateオブジェクトを持っており、状態を保持することができます。
そのため、アプリ起動時にStateオブジェクトにモデルをロードし、終了時に破棄するようなコードを書いてあげれば、毎回モデルを読み込む必要がなくなります。
具体的には、以下のようにFastAPIオブジェクトを生成している部分で、appオブジェクトをクロージャに包んで、イベントハンドラを設定してあげればいいです。
from fastapi import FastAPI from api.ml.models import MLModels def _startup_model(app: FastAPI) -> None: app.state.models = MLModels() def _shutdown_model(app: FastAPI) -> None: app.state.models = None def start_app_handler(app: FastAPI) -> Callable: def startup() -> None: _startup_model(app) return startup def stop_app_handler(app: FastAPI) -> Callable: def shutdown() -> None: _shutdown_model(app) return shutdown def create_app() -> FastAPI: app = FastAPI() app.add_event_handler("startup", start_app_handler(app)) app.add_event_handler("shutdown", stop_app_handler(app)) return app app = create_app()
テスト
今回、一番苦労したのが、このテスト周りでした。
初めに、ルーティング部分のテストを行ったのですが、ここでModelをstateに登録したことにより難しくなったポイントが出てきました。それは、PythonのMockerはその仕組み上、オブジェクトが生成された後ではPatchできないということです。
Mocker.patchは、import句の中身を入れ替えるような挙動をします。
ここで思い出してほしいのが、
import Hoge.Fuga as Fuga fuga = Hoge.Fuga import Hige.Fuga as Fuga fuga.say()
のようなコードの場合、fuga.say()で実行されるのは、初めに読み込んだHoge.Fugaオブジェクトのsayメゾットだということです。
そのため、Patchするのであれば、クラスがインスタンス化される前にPatchしなければなりません。しかし、startup_handerで初期化を行なっているので、テストコードが呼び出されるときには、モデルのインスタンス化が済んでおり、Patchしても反映されないという壁に当たりました。
その解決のため、DIライブラリを導入して、コードを書き直し、依存性を外部から注入するようにしました。
その結果、テストはうまくいくようになったのですが、コードの量が増え、読みづらくなってしまいました。小規模なプロジェクトでDIをすると、かかる手間に比べて、得られるメリットが少なくなってしまいがちですが、今回のプロジェクトでも同様の結果となってしまったわけです。
コードレビューでも、もしテストのためだけにDIをしているのであれば、少し大袈裟なので、別の方法を考えてみても良いかもしれないというアドバイスをいただきました。
そこで、少し迂遠な方法にはなりますが、Routerモジュール内でFastAPIオブジェクトからStateオブジェクトを取り出すGetterを定義し、そのGetterをPatch/DIするという手法をとることにしました。
コードとしては以下のようになります。
from api.ml.models import MLModels def getMLModels() -> Callable: def getter(app) -> MLModels: return app.state.models return getter recommend_router = APIRouter() @recommend_router.post( "/hoge", ) async def tag_recommend( request: Request, file: UploadFile = File(...), getter=Depends(getMLModels), ): models = getter(request.app) ...
テストコード
from tests.mock.mock_ml_models import MockMLModels @pytest.fixture(scope="module") def client() -> Generator: """APIのTestClientを返すFixture""" def getMockMLModels() -> Callable: def getter(app) -> MLModels: return MockMLModels() return getter app.dependency_overrides[getMLModels] = getMockMLModels with TestClient(app) as c: yield c def test_hoge_api(client: TestClient, rgb_png_img_file: str) -> None: """Hoge APIのレスポンスの形式を検証する。 """ with open(rgb_png_img_file, "rb") as img_file: response = client.post( "/hoge", files={"file": (rgb_png_img_file, img_file, "image/png")}, ) assert response.status_code == HTTPStatus.OK assert response.json() == { "hoge": "fuga" }
CIの構築
せっかくテストを書いたので、どうせならCIも組んでしまえということで、GitLab CIも組ませていただきました。具体的にかけた処理としては、テスト・Format Checker・Linter・MyPy(型チェック)になります。
今回のCIでは、特にFormatter/LinterをかけてCommitするとかはしなかったので、Pre-commitスクリプトで対応しました。
まとめ
今回は、長期インターンで取り組んだ、タグ推薦モデルの構築と、そのWeb API化について話させていただきました。大量のデータを使って、実データでモデルをトレーニングする機会は少ないので、貴重な経験をさせていただきました。
現状のモデルの3つの課題については、
- 使用モデルのアーキテクチャが古い
- VGG16からEfficientNetにすることができたため、アーキテクチャは新しくなった。
- 再学習できていない
- 最新のデータを使ってトレーニングしたため、暫くは再学習しなくても良くなった。また、再学習も、スクリプトを利用してすぐに訓練データの生成・前処理等を行うことができるようにした。
- 閾値を超えないとおすすめが表示されない
- 予測結果上位10件を常に表示することにより、解決した。
というように、解決することができました。
また、本番環境デプロイに向けたコードレビューをしていただいたので、業務コードを書くうえで、気をつけなければならないこと、Bad Practice等に気づくことができました。
- ここでのRecallはMacro/Recallを意味しています。(各クラスについてRecallを計算した後に、全体の平均をとっています。)適切な評価方法なのか、という点についてですが、1. pixivのデータセットにはつくべきタグがついていないものが多いため、Precisionで評価すると、不当に低い評価となる可能性が高い 2. 負例が非常に多いことから、Accuracyは正当な評価基準足り得ない(実際、全てのタグについて0と予測した場合でも、99.7%以上のAccuracyが出てしまう)3. AUC, ROC曲線については, 時間がなくて検証できなかった という背景があります。Recallについても、全てを1と予測することで100%となってしまう為、あまり良い指標とは言えませんが、体感的にモデルの品質をよく表していると感じたため採用しました。↩
- 本来であればPositiveと判定されるべきLabelがNegativeと判定されてしまうため、学習回数を重ねても出力が安定せず、全体的にNegativeに近い値を取ってしまいます。↩