I've been grinding LeetCode lately and the ML material I studied a while back has gotten a bit rusty, so I'm doing a quick refresher through Kaggle and picking up a few new things along the way.
This first exercise is the Natural Language Processing with Disaster Tweets competition on Kaggle.
HuggingFace Transformers
I had run into it once before but never actually written any code with it; my impression was that it's a very functional toolkit, and it even ships a pipeline that can be used straight from the command line. This time I went through the docs and tried it out by hand.
First, let's look at what the data looks like:

import pandas as pd

# Load the training split of the competition data
train_url = '/kaggle/input/nlp-getting-started/train.csv'
train_data = pd.read_csv(train_url)
train_data.head()
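The training CSV has id, keyword, location, text and target columns; out of curiosity, a quick look at the class balance (this check is my own addition, not part of the original notebook):

# Count how many tweets are labelled real disaster (1) vs. not (0)
print(train_data['target'].value_counts())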
Also note that this setup needs a wandb token (the Trainer logs to wandb when it's installed); just add the key yourself in the notebook's Add-ons (Secrets).
!pip install transformers datasets evaluate accelerate wandb

import wandb
from kaggle_secrets import UserSecretsClient

user_secrets = UserSecretsClient()
wandb_token = user_secrets.get_secret("wandb_token")
! wandb login $wandb_token
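As an aside, if you'd rather not log to wandb at all, the Trainer's logging integrations can simply be switched off; this is an alternative to the token setup above, not what I did here:

from transformers import TrainingArguments

# report_to="none" disables the wandb (and other) reporting integrations
args = TrainingArguments(output_dir='./disaster_NLP', report_to="none")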
Then we use the distilbert/distilbert-base-uncased model. I looked DistilBERT up a bit: it's a smaller model distilled from BERT, trained to match the teacher's softmax-with-temperature outputs, which trades a little accuracy for much better efficiency; see DistilBert for details.
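To make the distillation idea concrete, here is a minimal sketch of a softmax-with-temperature distillation loss; the function name and the temperature value are illustrative, not taken from the actual DistilBERT training code:

import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, T=2.0):
    # Softmax with temperature T softens the teacher's distribution,
    # exposing information carried by the non-target classes
    soft_teacher = F.softmax(teacher_logits / T, dim=-1)
    log_soft_student = F.log_softmax(student_logits / T, dim=-1)
    # KL divergence between the soft distributions, scaled by T^2 as in
    # Hinton et al.'s knowledge distillation formulation
    return F.kl_div(log_soft_student, soft_teacher, reduction="batchmean") * (T ** 2)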
from transformers import AutoTokenizer
Note there's a gotcha here: you need to rename the columns, otherwise training will fail later.
The text column has to be named text and the target has to be named label.
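The tokenization step isn't fully shown above; a minimal sketch of how tokenized_train, tokenized_valid and tokenized_test could be produced is below (the 80/20 validation split and the test-set path are my own choices, following the competition layout):

from datasets import Dataset

tokenizer = AutoTokenizer.from_pretrained('distilbert/distilbert-base-uncased')

# Rename the target column so the Trainer finds 'text' and 'label'
dataset = Dataset.from_pandas(train_data[['text', 'target']].rename(columns={'target': 'label'}))

def tokenize(batch):
    return tokenizer(batch['text'], truncation=True)

# Hold out part of the training data for evaluation
split = dataset.train_test_split(test_size=0.2, seed=42)
tokenized_train = split['train'].map(tokenize, batched=True)
tokenized_valid = split['test'].map(tokenize, batched=True)

# The test set gets tokenized the same way for prediction later
test_data = pd.read_csv('/kaggle/input/nlp-getting-started/test.csv')
tokenized_test = Dataset.from_pandas(test_data[['text']]).map(tokenize, batched=True)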
Finally, set the training arguments and define the Trainer:

from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
from transformers import DataCollatorWithPadding
model = AutoModelForSequenceClassification.from_pretrained('distilbert/distilbert-base-uncased', num_labels=2)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
training_args = TrainingArguments(
output_dir='./disaster_NLP',
learning_rate=2e-5,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
num_train_epochs=2,
weight_decay=0.01,
evaluation_strategy='epoch',
save_strategy='epoch',
load_best_model_at_end=True
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_train,
eval_dataset=tokenized_valid,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics
)
trainer.train()
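One piece that isn't shown in the snippet above is compute_metrics, which has to be defined before the Trainer is constructed; a minimal sketch using the evaluate library (accuracy as the metric is my assumption) could look like this:

import numpy as np
import evaluate

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    # eval_pred is a (logits, labels) pair; take the argmax to get class ids
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return accuracy.compute(predictions=predictions, references=labels)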
After training, write the predictions into the sample submission and submit:

import numpy as np

# Predict on the tokenized test set and take the argmax over the two classes
preds = trainer.predict(tokenized_test)
preds = np.argmax(preds.predictions, axis=1)

submission_path = '/kaggle/input/nlp-getting-started/sample_submission.csv'
submission_data = pd.read_csv(submission_path)
submission_data['target'] = preds
submission_data.to_csv('submission.csv', index=False)
This is a straight-from-the-docs implementation; with some hyperparameter tuning the score could probably be improved, but I won't go into that here.