
import os

import torch
from keras.preprocessing.sequence import pad_sequences
from torch.utils.data import TensorDataset, DataLoader, RandomSampler

from configs.base import config
from model.modeling_albert import BertConfig, BertModel
from model.tokenization_bert import BertTokenizer
  10. device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
  11. MAX_LEN = 10
  12. if __name__ == '__main__':
  13. bert_config = BertConfig.from_pretrained(str(config['albert_config_path']), share_type='all')
  14. base_path = os.getcwd()
  15. VOCAB = base_path + '/configs/vocab.txt' # your path for model and vocab
  16. tokenizer = BertTokenizer.from_pretrained(VOCAB)
  18. # encoder text
  19. tag2idx={'[SOS]':101, '[EOS]':102, '[PAD]':0, 'B_LOC':1, 'I_LOC':2, 'O':3}
  20. sentences = ['我是中华人民共和国国民', '我爱祖国']
  21. tags = ['O O B_LOC I_LOC I_LOC I_LOC I_LOC I_LOC O O', 'O O O O']
  23. tokenized_text = [tokenizer.tokenize(sent) for sent in sentences]
  24. #利用pad_sequence对序列长度进行截断和padding
  25. input_ids = pad_sequences([tokenizer.convert_tokens_to_ids(txt) for txt in tokenized_text], #没法一条一条处理,只能2-d的数据,即多于一条样本,但是如果全部加载到内存是不是会爆
  26. maxlen=MAX_LEN-2,
  27. truncating='post',
  28. padding='post',
  29. value=0)
  31. tag_ids = pad_sequences([[tag2idx.get(tok) for tok in tag.split()] for tag in tags],
  32. maxlen=MAX_LEN-2,
  33. padding="post",
  34. truncating="post",
  35. value=0)
  37. #bert中的句子前后需要加入[CLS]:101和[SEP]:102
  38. input_ids_cls_sep = []
  39. for input_id in input_ids:
  40. linelist = []
  41. linelist.append(101)
  42. flag = True
  43. for tag in input_id:
  44. if tag > 0:
  45. linelist.append(tag)
  46. elif tag == 0 and flag:
  47. linelist.append(102)
  48. linelist.append(tag)
  49. flag = False
  50. else:
  51. linelist.append(tag)
  52. if tag > 0:
  53. linelist.append(102)
  54. input_ids_cls_sep.append(linelist)
  56. tag_ids_cls_sep = []
  57. for tag_id in tag_ids:
  58. linelist = []
  59. linelist.append(101)
  60. flag = True
  61. for tag in tag_id:
  62. if tag > 0:
  63. linelist.append(tag)
  64. elif tag == 0 and flag:
  65. linelist.append(102)
  66. linelist.append(tag)
  67. flag = False
  68. else:
  69. linelist.append(tag)
  70. if tag > 0:
  71. linelist.append(102)
  72. tag_ids_cls_sep.append(linelist)
  74. attention_masks = [[int(tok > 0) for tok in line] for line in input_ids_cls_sep]
  76. print('---------------------------')
  77. print('input_ids:{}'.format(input_ids_cls_sep))
  78. print('tag_ids:{}'.format(tag_ids_cls_sep))
  79. print('attention_masks:{}'.format(attention_masks))
  81. # input_ids = torch.tensor([tokenizer.encode('我 是 中 华 人 民 共 和 国 国 民', add_special_tokens=True)]) #为True则句子首尾添加[CLS]和[SEP]
  82. # print('input_ids:{}, size:{}'.format(input_ids, len(input_ids)))
  83. # print('attention_masks:{}, size:{}'.format(attention_masks, len(attention_masks)))
  85. inputs_tensor = torch.tensor(input_ids_cls_sep)
  86. tags_tensor = torch.tensor(tag_ids_cls_sep)
  87. masks_tensor = torch.tensor(attention_masks)
  89. train_data = TensorDataset(inputs_tensor, tags_tensor, masks_tensor)
  90. train_sampler = RandomSampler(train_data)
  91. train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=2)
  93. model = BertModel.from_pretrained(config['bert_dir'],config=bert_config)
  94. model.to(device)
  95. model.eval()
  96. with torch.no_grad():
  97. '''
  98. note:
  99. 一.
  100. 如果设置:"output_hidden_states":"True"和"output_attentions":"True"
  101. 输出的是: 所有层的 sequence_output, pooled_output, (hidden_states), (attentions)
  102. 则 all_hidden_states, all_attentions = model(input_ids)[-2:]
  104. 二.
  105. 如果没有设置:output_hidden_states和output_attentions
  106. 输出的是:最后一层 --> (output_hidden_states, output_attentions)
  107. '''
  108. for index, batch in enumerate(train_dataloader):
  109. batch = tuple(t.to(device) for t in batch)
  110. b_input_ids, b_input_mask, b_labels = batch
  111. last_hidden_state = model(input_ids = b_input_ids,attention_mask = b_input_mask)
  112. print(len(last_hidden_state))
  113. all_hidden_states, all_attentions = last_hidden_state[-2:] #这里获取所有层的hidden_satates以及attentions
  114. print(all_hidden_states[-2].shape)#倒数第二层hidden_statesshape


input_ids:[[101, 2769, 3221, 704, 1290, 782, 3696, 1066, 1469, 102], [101, 2769, 4263, 4862, 1744, 102, 0, 0, 0, 0]]
tag_ids:[[101, 3, 3, 1, 2, 2, 2, 2, 2, 102], [101, 3, 3, 3, 3, 102, 0, 0, 0, 0]]
attention_masks:[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 0, 0, 0, 0]]
torch.Size([2, 10, 768])
tensor([[[-1.1074, -0.0047,  0.4608,  ..., -0.1816, -0.6379,  0.2295],
         [-0.1930, -0.4629,  0.4127,  ..., -0.5227, -0.2401, -0.1014],
         [ 0.2682, -0.6617,  0.2744,  ..., -0.6689, -0.4464,  0.1460],
         [-0.1723, -0.7065,  0.4111,  ..., -0.6570, -0.3490, -0.5541],
         [-0.2028, -0.7025,  0.3954,  ..., -0.6566, -0.3653, -0.5655],
         [-0.2026, -0.6831,  0.3778,  ..., -0.6461, -0.3654, -0.5523]],

        [[-1.3166, -0.0052,  0.6554,  ..., -0.2217, -0.5685,  0.4270],
         [-0.2755, -0.3229,  0.4831,  ..., -0.5839, -0.1757, -0.1054],
         [-1.4941, -0.1436,  0.8720,  ..., -0.8316, -0.5213, -0.3893],
         [-0.7022, -0.4104,  0.5598,  ..., -0.6664, -0.1627, -0.6270],
         [-0.7389, -0.2896,  0.6083,  ..., -0.7895, -0.2251, -0.4088],
         [-0.0351, -0.9981,  0.0660,  ..., -0.4606,  0.4439, -0.6745]]])


