【关系抽取-R-BERT】加载数据集

【关系抽取-R-BERT】模型结构

【关系抽取-R-BERT】定义训练和验证循环

相关代码

  1. import logging
  2. import os
  3. import numpy as np
  4. import torch
  5. from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
  6. from tqdm import tqdm, trange
  7. from transformers import AdamW, BertConfig, get_linear_schedule_with_warmup
  8. from model import RBERT
  9. from utils import compute_metrics, get_label, write_prediction
  10. logger = logging.getLogger(__name__)
  11. class Trainer(object):
  12. def __init__(self, args, train_dataset=None, dev_dataset=None, test_dataset=None):
  13. self.args = args
  14. self.train_dataset = train_dataset
  15. self.dev_dataset = dev_dataset
  16. self.test_dataset = test_dataset
  17. self.label_lst = get_label(args)
  18. self.num_labels = len(self.label_lst)
  19. self.config = BertConfig.from_pretrained(
  20. args.model_name_or_path,
  21. num_labels=self.num_labels,
  22. finetuning_task=args.task,
  23. id2label={str(i): label for i, label in enumerate(self.label_lst)},
  24. label2id={label: i for i, label in enumerate(self.label_lst)},
  25. )
  26. self.model = RBERT.from_pretrained(args.model_name_or_path, config=self.config, args=args)
  27. # GPU or CPU
  28. self.device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
  29. self.model.to(self.device)
  30. def train(self):
  31. train_sampler = RandomSampler(self.train_dataset)
  32. train_dataloader = DataLoader(
  33. self.train_dataset,
  34. sampler=train_sampler,
  35. batch_size=self.args.train_batch_size,
  36. )
  37. if self.args.max_steps > 0:
  38. t_total = self.args.max_steps
  39. self.args.num_train_epochs = (
  40. self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1
  41. )
  42. else:
  43. t_total = len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs
  44. # Prepare optimizer and schedule (linear warmup and decay)
  45. no_decay = ["bias", "LayerNorm.weight"]
  46. optimizer_grouped_parameters = [
  47. {
  48. "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
  49. "weight_decay": self.args.weight_decay,
  50. },
  51. {
  52. "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
  53. "weight_decay": 0.0,
  54. },
  55. ]
  56. optimizer = AdamW(
  57. optimizer_grouped_parameters,
  58. lr=self.args.learning_rate,
  59. eps=self.args.adam_epsilon,
  60. )
  61. scheduler = get_linear_schedule_with_warmup(
  62. optimizer,
  63. num_warmup_steps=self.args.warmup_steps,
  64. num_training_steps=t_total,
  65. )
  66. # Train!
  67. logger.info("***** Running training *****")
  68. logger.info(" Num examples = %d", len(self.train_dataset))
  69. logger.info(" Num Epochs = %d", self.args.num_train_epochs)
  70. logger.info(" Total train batch size = %d", self.args.train_batch_size)
  71. logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
  72. logger.info(" Total optimization steps = %d", t_total)
  73. logger.info(" Logging steps = %d", self.args.logging_steps)
  74. logger.info(" Save steps = %d", self.args.save_steps)
  75. global_step = 0
  76. tr_loss = 0.0
  77. self.model.zero_grad()
  78. train_iterator = trange(int(self.args.num_train_epochs), desc="Epoch")
  79. for _ in train_iterator:
  80. epoch_iterator = tqdm(train_dataloader, desc="Iteration")
  81. for step, batch in enumerate(epoch_iterator):
  82. self.model.train()
  83. batch = tuple(t.to(self.device) for t in batch) # GPU or CPU
  84. inputs = {
  85. "input_ids": batch[0],
  86. "attention_mask": batch[1],
  87. "token_type_ids": batch[2],
  88. "labels": batch[3],
  89. "e1_mask": batch[4],
  90. "e2_mask": batch[5],
  91. }
  92. outputs = self.model(**inputs)
  93. loss = outputs[0]
  94. if self.args.gradient_accumulation_steps > 1:
  95. loss = loss / self.args.gradient_accumulation_steps
  96. loss.backward()
  97. tr_loss += loss.item()
  98. if (step + 1) % self.args.gradient_accumulation_steps == 0:
  99. torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)
  100. optimizer.step()
  101. scheduler.step() # Update learning rate schedule
  102. self.model.zero_grad()
  103. global_step += 1
  104. if self.args.logging_steps > 0 and global_step % self.args.logging_steps == 0:
  105. self.evaluate("test") # There is no dev set for semeval task
  106. if self.args.save_steps > 0 and global_step % self.args.save_steps == 0:
  107. self.save_model()
  108. if 0 < self.args.max_steps < global_step:
  109. epoch_iterator.close()
  110. break
  111. if 0 < self.args.max_steps < global_step:
  112. train_iterator.close()
  113. break
  114. return global_step, tr_loss / global_step
  115. def evaluate(self, mode):
  116. # We use test dataset because semeval doesn't have dev dataset
  117. if mode == "test":
  118. dataset = self.test_dataset
  119. elif mode == "dev":
  120. dataset = self.dev_dataset
  121. else:
  122. raise Exception("Only dev and test dataset available")
  123. eval_sampler = SequentialSampler(dataset)
  124. eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=self.args.eval_batch_size)
  125. # Eval!
  126. logger.info("***** Running evaluation on %s dataset *****", mode)
  127. logger.info(" Num examples = %d", len(dataset))
  128. logger.info(" Batch size = %d", self.args.eval_batch_size)
  129. eval_loss = 0.0
  130. nb_eval_steps = 0
  131. preds = None
  132. out_label_ids = None
  133. self.model.eval()
  134. for batch in tqdm(eval_dataloader, desc="Evaluating"):
  135. batch = tuple(t.to(self.device) for t in batch)
  136. with torch.no_grad():
  137. inputs = {
  138. "input_ids": batch[0],
  139. "attention_mask": batch[1],
  140. "token_type_ids": batch[2],
  141. "labels": batch[3],
  142. "e1_mask": batch[4],
  143. "e2_mask": batch[5],
  144. }
  145. outputs = self.model(**inputs)
  146. tmp_eval_loss, logits = outputs[:2]
  147. eval_loss += tmp_eval_loss.mean().item()
  148. nb_eval_steps += 1
  149. if preds is None:
  150. preds = logits.detach().cpu().numpy()
  151. out_label_ids = inputs["labels"].detach().cpu().numpy()
  152. else:
  153. preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
  154. out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)
  155. eval_loss = eval_loss / nb_eval_steps
  156. results = {"loss": eval_loss}
  157. preds = np.argmax(preds, axis=1)
  158. write_prediction(self.args, os.path.join(self.args.eval_dir, "proposed_answers.txt"), preds)
  159. result = compute_metrics(preds, out_label_ids)
  160. results.update(result)
  161. logger.info("***** Eval results *****")
  162. for key in sorted(results.keys()):
  163. logger.info(" {} = {:.4f}".format(key, results[key]))
  164. return results
  165. def save_model(self):
  166. # Save model checkpoint (Overwrite)
  167. if not os.path.exists(self.args.model_dir):
  168. os.makedirs(self.args.model_dir)
  169. model_to_save = self.model.module if hasattr(self.model, "module") else self.model
  170. model_to_save.save_pretrained(self.args.model_dir)
  171. # Save training arguments together with the trained model
  172. torch.save(self.args, os.path.join(self.args.model_dir, "training_args.bin"))
  173. logger.info("Saving model checkpoint to %s", self.args.model_dir)
  174. def load_model(self):
  175. # Check whether model exists
  176. if not os.path.exists(self.args.model_dir):
  177. raise Exception("Model doesn't exists! Train first!")
  178. self.args = torch.load(os.path.join(self.args.model_dir, "training_args.bin"))
  179. self.model = RBERT.from_pretrained(self.args.model_dir, args=self.args)
  180. self.model.to(self.device)
  181. logger.info("***** Model Loaded *****")

说明

整个代码的流程就是:

  • 定义训练数据;
  • 定义模型;
  • 定义优化器;
  • 如果是训练,将模型切换到训练状态;model.train(),读取数据进行损失计算,反向传播更新参数;
  • 如果是验证或者测试,将模型切换到验证状态:model.eval(),相关计算要用with torch.no_grad()进行包裹,并在里面进行损失的计算、相关评价指标的计算或者预测;

使用到的一些技巧

采样器的使用

在训练的时候,我们使用的是RandomSampler采样器,在验证或者测试的时候,我们使用的是SequentialSampler采样器,关于这些采样器的区别,可以去这里看一下:

https://chenllliang.github.io/2020/02/04/dataloader/

这里简要提一下这两种的区别,训练的时候是打乱数据再进行读取,验证的时候顺序读取数据。

使用梯度累加

核心代码:

  1. if (step + 1) % self.args.gradient_accumulation_steps == 0:
  2. torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)
  3. optimizer.step()
  4. scheduler.step() # Update learning rate schedule
  5. self.model.zero_grad()
  6. global_step += 1

梯度累加的作用是当显存不足的时候可以变相的增加batchsize,具体就不作展开了。

不同参数设置权重衰减

核心代码:

  1. no_decay = ["bias", "LayerNorm.weight"]
  2. optimizer_grouped_parameters = [
  3. {
  4. "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
  5. "weight_decay": self.args.weight_decay,
  6. },
  7. {
  8. "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
  9. "weight_decay": 0.0,
  10. },
  11. ]
  12. optimizer = AdamW(
  13. optimizer_grouped_parameters,
  14. lr=self.args.learning_rate,
  15. eps=self.args.adam_epsilon,
  16. )

有的参数是不需要进行权重衰减的,我们可以分别设置。

warmup的使用

核心代码:

  1. scheduler = get_linear_schedule_with_warmup(
  2. optimizer,
  3. num_warmup_steps=self.args.warmup_steps,
  4. num_training_steps=t_total,
  5. )

看一张图:



warmup就是在初始阶段逐渐增大学习率到指定的数值,这么做是为了避免在模型训练的初期的不稳定问题。

代码来源:https://github.com/monologg/R-BERT

【关系抽取-R-BERT】定义训练和验证循环的更多相关文章

  1. 9. 获得图片路径,构造出训练集和验证集,同时构造出相同人脸和不同人脸的测试集,将结果存储为.csv格式 1.random.shuffle(数据清洗) 2.random.sample(从数据集中随机选取2个数据) 3. random.choice(从数据集中抽取一个数据) 4.pickle.dump(将数据集写成.pkl数据)

    1. random.shuffle(dataset) 对数据进行清洗操作 参数说明:dataset表示输入的数据 2.random.sample(dataset, 2) 从dataset数据集中选取2 ...

  2. NLP(二十一)人物关系抽取的一次实战

      去年,笔者写过一篇文章利用关系抽取构建知识图谱的一次尝试,试图用现在的深度学习办法去做开放领域的关系抽取,但是遗憾的是,目前在开放领域的关系抽取,还没有成熟的解决方案和模型.当时的文章仅作为笔者的 ...

  3. 人工智能论文解读精选 | PRGC:一种新的联合关系抽取模型

    NLP论文解读 原创•作者 | 小欣   论文标题:PRGC: Potential Relation and Global Correspondence Based Joint Relational ...

  4. 一次关于关系抽取(RE)综述调研的交流心得

    本文来自于一次交流的的记录,{}内的为个人体会. 基本概念 实事知识:实体-关系-实体的三元组.比如, 知识图谱:大量实时知识组织在一起,可以构建成知识图谱. 关系抽取:由于文本中蕴含大量事实知识,需 ...

  5. 谷歌BERT预训练源码解析(三):训练过程

    目录前言源码解析主函数自定义模型遮蔽词预测下一句预测规范化数据集前言本部分介绍BERT训练过程,BERT模型训练过程是在自己的TPU上进行的,这部分我没做过研究所以不做深入探讨.BERT针对两个任务同 ...

  6. 【例3】设有关系模式R(A, B, C, D, E)与它的函数依赖集F={A→BC, CD→E, B→D, E→A},求R的所有候选键。 解题思路:

    通过分析F发现,其所有的属性A.B.C.D.E都是LR类属性,没有L类.R类.N类属性. 因此,先从这些属性中依次取出一个属性,分别求它们的闭包:=ABCDE,=BD,=C,=D, =ABCDE.由于 ...

  7. 基于BERT预训练的中文命名实体识别TensorFlow实现

    BERT-BiLSMT-CRF-NERTensorflow solution of NER task Using BiLSTM-CRF model with Google BERT Fine-tuni ...

  8. 【python实现卷积神经网络】定义训练和测试过程

    代码来源:https://github.com/eriklindernoren/ML-From-Scratch 卷积神经网络中卷积层Conv2D(带stride.padding)的具体实现:https ...

  9. 用NVIDIA-NGC对BERT进行训练和微调

    用NVIDIA-NGC对BERT进行训练和微调 Training and Fine-tuning BERT Using NVIDIA NGC 想象一下一个比人类更能理解语言的人工智能程序.想象一下为定 ...

随机推荐

  1. Caffe入门:对于抽象概念的图解分析

    Caffe的几个重要文件 用了这么久Caffe都没好好写过一篇新手入门的博客,最近应实验室小师妹要求,打算写一篇简单.快熟入门的科普文. 利用Caffe进行深度神经网络训练第一步需要搞懂几个重要文件: ...

  2. Gym 101174D Dinner Bet(概率DP)题解

    题意:n个球,两个人每人选C个球作为目标,然后放回.每回合有放回的拿出D个球,如果有目标球,就实现了这个目标,直到至少一个人实现了所有目标游戏结束.问结束回合的期望.误差1e-3以内. 思路:概率DP ...

  3. 在利用手背扫描图像+K因子 对室内温度进行回归预测时碰到的问题

    1. 关于多输入流: 由于本Mission是双输入, 导师要求尽量能使用Inception之诸, 于是输入便成了问题. 思考: 在Github上找到了keras-inceptionV4进行对网络头尾的 ...

  4. Google Developer Days 2019 & GDD

    Google Developer Days 2019 2019 Google 开发者大会 GDD Google Developer Days https://events.google.cn/intl ...

  5. css & background-image & full page width & background-size

    css & background-image & full page width & background-size https://css-tricks.com/perfec ...

  6. HTTP/3 protocol

    HTTP/3 protocol https://caniuse.com/#feat=http3 HTTP/3 H3 https://en.wikipedia.org/wiki/HTTP/3 QUIC ...

  7. uniapp 扫二维码跳转

    在h5和wxapp中 生成qrcode的组件 https://ext.dcloud.net.cn/plugin?id=39 wx小程序扫二位码文档 生成链接时 computed: { ...mapSt ...

  8. 一些小Tip

    导语 个人感悟,持续更新中... 正文 无论NIO还是AIO,都没有在数据传输过程(tcp/udp)作革命性的创新.他们在传输过程的效率和传统BIO是一样的,还是会产生阻塞(网络延迟,Socket缓冲 ...

  9. Java自学no.1———带你初步认识java

    什么是Java Java语言是美国Sun公司(Stanford University Network),在1995年推出的高级的编程语言.所谓编程语言,是 计算机的语言,人们可以使用编程语言对计算机下 ...

  10. dotnet core TargetFramework 解析顺序测试

    dotnet core TargetFramework 解析顺序测试 Intro 现在 dotnet 的 TargetFramework 越来越多,抛开 .NET Framework 不谈,如果一个类 ...