web-dev-qa-db-ja.com

ケガスバージョンのhuggingfaceトランスフォーマー用のbiobert

https://github.com/dmis-lab/biobert/issues/98 にも投稿されています)

こんにちは、huggingfaceトランスフォーマー(バージョン2.4.1)を使用してbiobertをケラスレイヤーとしてロードする方法を知っている人はいますか?私はいくつかの可能性を試しましたが、どれもうまくいきませんでした。私が見つけたすべてはpytorchバージョンの使い方ですが、kerasレイヤーバージョンに興味があります。以下は2つの試みです(biobertファイルをフォルダー "biobert_v1.1_pubmed"に保存しました)。

試み1:

biobert_model = TFBertModel.from_pretrained('bert-base-uncased')
biobert_model.load_weights('biobert_v1.1_pubmed/model.ckpt-1000000')

エラーメッセージ:

AssertionError: Some objects had attributes which were not restored:
    : ['tf_bert_model_4/bert/embeddings/Word_embeddings/weight']
    : ['tf_bert_model_4/bert/embeddings/position_embeddings/embeddings']
   (and many more lines like above...)

試み2:

biobert_model = TFBertModel.from_pretrained("biobert_v1.1_pubmed/model.ckpt-1000000", config='biobert_v1.1_pubmed/bert_config.json')

エラーメッセージ:

NotImplementedError: Weights may only be loaded based on topology into Models when loading TensorFlow-formatted weights (got by_name=True to load_weights).

助けてくれてありがとう! huggingfaceのトランスフォーマーライブラリーに関する私の経験はほとんどゼロです。次の2つのモデルもロードしようとしましたが、pytorchバージョンしかサポートしていないようです。

3
dmollaaliod

少し遅れるかもしれませんが、この問題に対するそれほどエレガントな修正は見つかりませんでした。トランスフォーマーライブラリーのtf bertモデルは、PyTorch保存ファイルでロードできます。

ステップ1:次のコマンドを使用して、tfチェックポイントをPytorch保存ファイルに変換します(詳細: https://github.com/ huggingface/transformers/blob/master/docs/source/converting_tensorflow_models.rst

transformers-cli convert --model_type bert\
  --tf_checkpoint=./path/to/checkpoint_file \
  --config=./bert_config.json \
  --pytorch_dump_output=./pytorch_model.bin

ステップ2:ディレクトリ内の以下のファイルを必ず結合してください

  • config.json-bert構成ファイル(bert_config.jsonから名前を変更する必要があります!)
  • pytorch_model.bin-変換したもの
  • vocab.txt-bert vocabファイル

ステップ3:作成したディレクトリからモデルをロードします

model = TFBertModel.from_pretrained('./pretrained_model_dir', from_pt=True)

実際には「from_tf」という引数もあります。ドキュメントによると、これはtfスタイルのチェックポイントで機能するはずですが、機能しません。参照: https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained

3
kaorusss