乡下人产国偷v产偷v自拍,国产午夜片在线观看,婷婷成人亚洲综合国产麻豆,久久综合给合久久狠狠狠9

  • <output id="e9wm2"></output>
    <s id="e9wm2"><nobr id="e9wm2"><ins id="e9wm2"></ins></nobr></s>

    • 分享

      超詳細(xì)中文注釋的GPT2新聞標(biāo)題生成項(xiàng)目

       520jefferson 2021-01-14

      筆者開(kāi)源了一個(gè)帶有超詳細(xì)中文注釋的GPT2新聞標(biāo)題生成項(xiàng)目。

      該項(xiàng)目參考了GPT2-Chinese、GPT2-chitchat、CDial-GPT、GPT2等多個(gè)GPT2開(kāi)源項(xiàng)目(感謝大佬們的開(kāi)源),并根據(jù)自己的理解,將代碼進(jìn)行重構(gòu),添加詳細(xì)注釋,希望可以幫助到有需要的同學(xué)。

      項(xiàng)目是基于HuggingFace的transformers實(shí)現(xiàn)GPT2模型代碼進(jìn)行修改、訓(xùn)練及測(cè)試。并且通過(guò)Flask框架搭建了一個(gè)Web服務(wù),將新聞標(biāo)題生成模型進(jìn)行工程化,可以通過(guò)頁(yè)面,可視化地體驗(yàn)新聞標(biāo)題生成效果。

      該項(xiàng)目的目的是帶領(lǐng)大家走一遍GPT2生成模型的訓(xùn)練、測(cè)試及部署全部流程。

      項(xiàng)目地址:https://github.com/liucongg/GPT2-NewsTitle

      本文主要是對(duì)項(xiàng)目中的代碼進(jìn)行講解,主要從數(shù)據(jù)預(yù)處理、數(shù)據(jù)類實(shí)現(xiàn)、模型代碼實(shí)現(xiàn)、模型訓(xùn)練、模型測(cè)試和模型上線,六個(gè)部分進(jìn)行介紹,如下。

      數(shù)據(jù)預(yù)處理

      數(shù)據(jù)來(lái)源于新浪微博,由He Zhengfang大佬整理,詳細(xì)鏈接如下:https://www.jianshu.com/p/8f52352f0748?tdsourcetag=s_pcqq_aiomsg。

      由于數(shù)據(jù)來(lái)自微博,在標(biāo)題中常常帶有“話題”、“表情”標(biāo)記,在正文中常常帶有“HTML”標(biāo)記,如下:

      Title:
      2014#福布斯中國(guó)名人榜#:她再奪冠[威武]
      Content:
      為什么我們要工作?聽(tīng)演講者Barry Schwartz告訴你工作的另一個(gè)重要意義。非常有深度的一個(gè)演講,值得一看!http:///RqzKvtn 轉(zhuǎn)發(fā)學(xué)習(xí),給自己的工作加油打氣吧![good]

      因此需要對(duì)數(shù)據(jù)進(jìn)行清洗,具體如下:

      (1)對(duì)標(biāo)題清洗時(shí),會(huì)去除“##”符號(hào)(一般為微博數(shù)據(jù)的話題標(biāo)記)、去除“[]”中間的文字(一般為微博數(shù)據(jù)中的表情)、合并標(biāo)題中過(guò)多的空格

      def clean_weibo_title(title: str):
      '''
      對(duì)微博數(shù)據(jù)中的標(biāo)題內(nèi)容(待生成)進(jìn)行清洗
      Args:
      title: 標(biāo)題
      Returns:
      '''
      # 去除##符號(hào)(一般為微博數(shù)據(jù)的話題標(biāo)記)
      title = re.sub(r'#', '', title)
      # 去除[]中間的文字(一般為微博數(shù)據(jù)中的表情)
      title = re.sub(r'(\[{1,2})(.*?)(\]{1,2})', '', title)
      # 合并標(biāo)題中過(guò)多的空格
      title = re.sub(r'\s+', ' ', title)
      return title

      (2)對(duì)正文清洗時(shí),會(huì)去除網(wǎng)址、合并正文中過(guò)多的空格、去除“\u200b”字符

      def clean_weibo_content(content: str):
      '''
      對(duì)微博數(shù)據(jù)中的文本內(nèi)容進(jìn)行清洗
      Args:
      content: 文本
      Returns:
      '''
      # 去除網(wǎng)址
      content = re.sub(r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', '', content)
      # 合并正文中過(guò)多的空格
      content = re.sub(r'\s+', ' ', content)
      # 去除\u200b字符
      content = content.replace('\u200b', '')
      return content

      (3)對(duì)清洗后的數(shù)據(jù)進(jìn)行整合,去除重復(fù)數(shù)據(jù)、正文內(nèi)容字?jǐn)?shù)小于100的數(shù)據(jù)和標(biāo)題內(nèi)容字?jǐn)?shù)小于2的數(shù)據(jù);并且拆分訓(xùn)練集和測(cè)試集。

      def build_news_data(content_path, title_path, train_save_path, test_save_path):
      '''
      對(duì)微博數(shù)據(jù)進(jìn)行清洗,構(gòu)建訓(xùn)練集和測(cè)試集
      Args:
      content_path: 正文內(nèi)容文件路徑
      title_path: 標(biāo)題內(nèi)容文件路徑
      train_save_path: 訓(xùn)練集文件路徑
      test_save_path: 測(cè)試集文件路徑
      Returns:
      '''
      # 打開(kāi)文件,并將其zip成一個(gè)文件
      content_data = open(content_path, 'r', encoding='utf-8')
      title_data = open(title_path, 'r', encoding='utf-8')
      data = zip(content_data.readlines(), title_data.readlines())
      # 使用多進(jìn)程處理數(shù)據(jù)
      threads = min(8, cpu_count())
      with Pool(threads) as p:
      annoate_ = partial(clean_data)
      data = list(tqdm(p.imap(annoate_, data, chunksize=8),
      desc='build data'
      )
      )
      # 對(duì)數(shù)據(jù)進(jìn)行過(guò)濾,去除重復(fù)數(shù)據(jù)、正文內(nèi)容字長(zhǎng)小于100的數(shù)據(jù)和標(biāo)題內(nèi)容字長(zhǎng)小于100的數(shù)據(jù)
      data_set = set()
      data_new = []
      for d in data:
      if d['content'] in data_set or len(d['content']) < 100 or len(d['title']) < 2:
      continue
      else:
      data_set.add(d['content'])
      data_new.append(d)
      # 拆分?jǐn)?shù)據(jù),構(gòu)建訓(xùn)練集和測(cè)試集
      random.shuffle(data_new)
      train_data = data_new[:-3000]
      test_data = data_new[-3000:]
      fin = open(train_save_path, 'w', encoding='utf-8')
      fin.write(json.dumps(train_data, indent=4, ensure_ascii=False))
      fin.close()
      fin = open(test_save_path, 'w', encoding='utf-8')
      fin.write(json.dumps(test_data, indent=4, ensure_ascii=False))
      fin.close()

      詳細(xì)代碼見(jiàn)Github項(xiàng)目的data_helper.py文件。

      數(shù)據(jù)類實(shí)現(xiàn)

      數(shù)據(jù)類的作用是將文本數(shù)據(jù)轉(zhuǎn)換成模型可以使用的索引數(shù)據(jù),并預(yù)先存儲(chǔ)下來(lái)。避免模型每訓(xùn)練一步,都進(jìn)行無(wú)效的數(shù)據(jù)轉(zhuǎn)換操作。

      (1)判斷是否存在緩存文件,如果存在,則直接加載;否則重新將文本數(shù)據(jù)轉(zhuǎn)換為索引數(shù)據(jù),并存為緩存。

      if os.path.exists(cached_feature_file) and not is_overwrite:
      logger.info('已經(jīng)存在緩存文件{},直接加載'.format(cached_feature_file))
      self.data_set = torch.load(cached_feature_file)['data_set']
      # 如果緩存數(shù)據(jù)不存在,則對(duì)原始數(shù)據(jù)進(jìn)行數(shù)據(jù)處理操作,并將處理后的數(shù)據(jù)存成緩存文件
      else:
      logger.info('不存在緩存文件{},進(jìn)行數(shù)據(jù)預(yù)處理操作'.format(cached_feature_file))
      self.data_set = self.load_data(path_file)
      logger.info('數(shù)據(jù)預(yù)處理操作完成,將處理后的數(shù)據(jù)存到{}中,作為緩存文件'.format(cached_feature_file))
      torch.save({'data_set': self.data_set}, cached_feature_file)

      (2)將文本數(shù)據(jù)轉(zhuǎn)換為索引數(shù)據(jù)的函數(shù)

      def convert_feature(self, sample):
      '''
      數(shù)據(jù)處理函數(shù)
      Args:
      sample: 一個(gè)字典,包含新聞的正文和新聞的標(biāo)題,格式為{'content': content, 'title': title}
      Returns:
      '''
      input_ids = []
      token_type_ids = []
      # 對(duì)新聞?wù)倪M(jìn)行tokenizer.tokenize分詞
      content_tokens = self.tokenizer.tokenize(sample['content'])
      # 對(duì)新聞標(biāo)題進(jìn)行tokenizer.tokenize分詞,注意tokenizer中已經(jīng)將[Space]作為一個(gè)分隔符,不會(huì)切割成多個(gè)字符
      title_tokens = self.tokenizer.tokenize(sample['title'].replace(' ', '[Space]'))
      # 判斷如果正文過(guò)長(zhǎng),進(jìn)行截?cái)?br> if len(content_tokens) > self.max_len - len(title_tokens) - 3:
      content_tokens = content_tokens[:self.max_len - len(title_tokens) - 3]
      # 生成模型所需的input_ids和token_type_ids
      input_ids.append(self.tokenizer.cls_token_id)
      token_type_ids.append(self.content_id)
      input_ids.extend(self.tokenizer.convert_tokens_to_ids(content_tokens))
      token_type_ids.extend([self.content_id] * len(content_tokens))
      input_ids.append(self.tokenizer.sep_token_id)
      token_type_ids.append(self.content_id)
      input_ids.extend(self.tokenizer.convert_tokens_to_ids(title_tokens))
      token_type_ids.extend([self.title_id] * len(title_tokens))
      input_ids.append(self.tokenizer.sep_token_id)
      token_type_ids.append(self.title_id)
      # 判斷input_ids與token_type_ids長(zhǎng)度是否一致
      assert len(input_ids) == len(token_type_ids)
      # 判斷input_ids長(zhǎng)度是否小于等于最大長(zhǎng)度
      assert len(input_ids) <= self.max_len
      return input_ids, token_type_ids

      詳細(xì)代碼見(jiàn)Github項(xiàng)目的data_set.py文件。

      模型代碼實(shí)現(xiàn)

      模型部分,主要對(duì)transformers包中GPT2LMHeadModel類進(jìn)行重寫(xiě),修改計(jì)算loss部分,只計(jì)算預(yù)測(cè)title部分的loss。

      模型的輸入由word embedding、segment embedding和position embedding三部分組成,具體如下圖所示:

      為什么需要加segment embedding?
      為了更好地區(qū)分Content和Title,并且根據(jù)token type id可以僅計(jì)算title部分的損失值。

      def forward(self, input_ids=None, past=None, token_type_ids=None, labels=None, title_id=None):
      '''
      前向函數(shù),計(jì)算GPT2預(yù)測(cè)結(jié)果值
      Args:
      input_ids: 輸入序列在詞表中的索引序列,size:[batch_size, sequence_length]
      past: 包含由模型預(yù)先計(jì)算好的隱藏狀態(tài),一般使用在預(yù)測(cè)階段,用于加速順序解碼,防止重復(fù)計(jì)算前面計(jì)算過(guò)的token
      token_type_ids: 用于區(qū)分輸入序列中content和title的分隔符序列,size:[batch_size, sequence_length]
      labels: 標(biāo)簽序列,size:[batch_size, sequence_length],一般情況下,與input_ids相同
      title_id: title部分分隔符的id
      Returns:
      '''
      # 獲取GPT2模型的輸出結(jié)果
      transformer_outputs = self.transformer(input_ids, past=past, token_type_ids=token_type_ids)
      # 獲取GPT2模型的最后一層的隱層節(jié)點(diǎn)狀態(tài),size:[batch_size, sequence_length, config.n_embd]
      hidden_states = transformer_outputs[0]
      # 預(yù)測(cè)隱層節(jié)點(diǎn)狀態(tài)中的每一個(gè)token的下一個(gè)token,size:[batch_size, sequence_length, config.vocab_size]
      lm_logits = self.lm_head(hidden_states)
      # 拼接輸出結(jié)果
      outputs = (lm_logits,) + transformer_outputs[1:]
      # 如果labels不為None時(shí),計(jì)算損失值loss,并拼接到輸出結(jié)果中
      if labels is not None:
      # 計(jì)算loss時(shí),title_id不可以為None,因?yàn)樾枰猼itle_id找到title的部分
      if title_id is None or token_type_ids is None:
      raise Exception('當(dāng)labels不為None時(shí), title_id和token_type_ids均不可以為None。')
      # 獲取mask值,如果token_type_ids中等于title_id的部分需要計(jì)算loss,標(biāo)記為1;否則為0。
      # size:[batch_size, sequence_length]
      mask = (token_type_ids == title_id).long()
      # 獲取新的標(biāo)簽,size:[batch_size, sequence_length]
      labels = labels * mask
      # 對(duì)預(yù)測(cè)結(jié)果和標(biāo)簽進(jìn)行偏移操作
      # GPT2的生成機(jī)制為通過(guò)前面的token,預(yù)測(cè)下一個(gè)token;并且labels與input_ids相同,
      # 因此input_ids中的第一個(gè)token的預(yù)測(cè)結(jié)果,實(shí)際上是標(biāo)簽中的第二個(gè)token,以此類推,最終僅計(jì)算sequence_length-1個(gè)token的loss
      shift_logits = lm_logits[..., :-1, :].contiguous()
      shift_labels = labels[..., 1:].contiguous()

      # 定義損失函數(shù)CrossEntropyLoss,并且設(shè)置忽略計(jì)算loss的索引,以及返回loss的形式
      # 忽略shift_labels中為0的loss,也就是僅計(jì)算title部分的損失值
      # 對(duì)loss的計(jì)算方式設(shè)為sum,由于我們僅計(jì)算了itle部分的損失值,如果使用mean,會(huì)使loss變?。▽?shí)際除的是sequence_length-1,不是title部分的真實(shí)長(zhǎng)度)
      loss_fct = CrossEntropyLoss(ignore_index=0, reduction='sum')
      loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
      # 獲取title部分的真實(shí)長(zhǎng)度,并計(jì)算真實(shí)loss
      num = shift_labels.ne(0).long().sum().item()
      loss = loss / num
      outputs = (loss,) + outputs
      return outputs # (loss), lm_logits, presents, (all hidden_states), (attentions)

      詳細(xì)代碼見(jiàn)Github項(xiàng)目的model.py文件。

      模型訓(xùn)練

      模型訓(xùn)練參數(shù)如下圖所示:

      模型訓(xùn)練執(zhí)行代碼如下:

      python3 train.py

      python3 train.py --output_dir output_dir/(自定義保存模型路徑)

      模型訓(xùn)練文件主要由以下幾個(gè)函數(shù)組成:(1)設(shè)置訓(xùn)練模型所需參數(shù)函數(shù)set_args;(2)訓(xùn)練模型函數(shù)train;(3)對(duì)測(cè)試數(shù)據(jù)集進(jìn)行模型測(cè)試evaluate;(4)主函數(shù)main。

      詳細(xì)代碼見(jiàn)Github項(xiàng)目的train.py文件。

      值得注意的是,在實(shí)例化tokenizer時(shí),一定要使用tokenizer.add_tokens('[Space]', special_tokens=True),目的是為了將[Space]作為一個(gè)切分整體,例如:'我愛(ài)[Space]北京天安門。',使用原始tokenizer分詞結(jié)果為'['我', '愛(ài)', '[', 'Space', ']', '北', '京', '天', '安','門','。']';增加切分符號(hào)后的結(jié)果為'['我', '愛(ài)', '[Space]', '北', '京', '天', '安','門','。']'。

      模型測(cè)試

      模型測(cè)試部分,主要是通過(guò)不同的解碼策略,對(duì)已經(jīng)訓(xùn)練好的模型進(jìn)行單個(gè)樣本的預(yù)測(cè)。

      (1)top_k或top_p解碼策略,僅保留top_k個(gè)或累積概率到達(dá)top_p的標(biāo)記,其他標(biāo)記設(shè)為filter_value,后續(xù)在選取標(biāo)記的過(guò)程中會(huì)取不到值設(shè)為無(wú)窮小。

      def top_k_top_p_filtering(logits, top_k, top_p, filter_value=-float('Inf')):
      '''
      top_k或top_p解碼策略,僅保留top_k個(gè)或累積概率到達(dá)top_p的標(biāo)記,其他標(biāo)記設(shè)為filter_value,后續(xù)在選取標(biāo)記的過(guò)程中會(huì)取不到值設(shè)為無(wú)窮小。
      Args:
      logits: 預(yù)測(cè)結(jié)果,即預(yù)測(cè)成為詞典中每個(gè)詞的分?jǐn)?shù)
      top_k: 只保留概率最高的top_k個(gè)標(biāo)記
      top_p: 只保留概率累積達(dá)到top_p的標(biāo)記
      filter_value: 過(guò)濾標(biāo)記值
      Returns:
      '''
      # logits的維度必須為2,即size:[batch_size, vocab_size]
      assert logits.dim() == 2
      # 獲取top_k和字典大小中較小的一個(gè),也就是說(shuō),如果top_k大于字典大小,則取字典大小個(gè)標(biāo)記
      top_k = min(top_k, logits[0].size(-1))
      # 如果top_k不為0,則將在logits中保留top_k個(gè)標(biāo)記
      if top_k > 0:
      # 由于有batch_size個(gè)預(yù)測(cè)結(jié)果,因此對(duì)其遍歷,選取每個(gè)預(yù)測(cè)結(jié)果的top_k標(biāo)記
      for logit in logits:
      indices_to_remove = logit < torch.topk(logit, top_k)[0][..., -1, None]
      logit[indices_to_remove] = filter_value
      # 如果top_p不為0,則將在logits中保留概率值累積達(dá)到top_p的標(biāo)記
      if top_p > 0.0:
      # 對(duì)logits進(jìn)行遞減排序
      sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
      # 對(duì)排序后的結(jié)果使用softmax歸一化,再獲取累積概率序列
      # 例如:原始序列[0.1, 0.2, 0.3, 0.4],則變?yōu)椋篬0.1, 0.3, 0.6, 1.0]
      cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
      # 刪除累積概率高于top_p的標(biāo)記
      sorted_indices_to_remove = cumulative_probs > top_p
      # 將索引向右移動(dòng),使第一個(gè)標(biāo)記也保持在top_p之上
      sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
      sorted_indices_to_remove[..., 0] = 0
      for index, logit in enumerate(logits):
      # 由于有batch_size個(gè)預(yù)測(cè)結(jié)果,因此對(duì)其遍歷,選取每個(gè)預(yù)測(cè)結(jié)果的累積概率達(dá)到top_p的標(biāo)記
      indices_to_remove = sorted_indices[index][sorted_indices_to_remove[index]]
      logit[indices_to_remove] = filter_value
      return logits

      (2)對(duì)單個(gè)樣本進(jìn)行預(yù)測(cè)

      def predict_one_sample(model, tokenizer, device, args, content):
      '''
      對(duì)單個(gè)樣本進(jìn)行預(yù)測(cè)
      Args:
      model: 模型
      tokenizer: 分詞器
      device: 設(shè)備信息
      args: 配置項(xiàng)信息
      content: 新聞?wù)?br> Returns:
      '''
      # 對(duì)新聞?wù)倪M(jìn)行預(yù)處理,并判斷如果超長(zhǎng)則進(jìn)行截?cái)?br> content_tokens = tokenizer.tokenize(content)
      if len(content_tokens) > args.max_len - 3 - args.generate_max_len:
      content_tokens = content_tokens[:args.max_len - 3 - args.generate_max_len]
      # 獲取content_id、title_id、unk_id、sep_id值
      content_id = tokenizer.convert_tokens_to_ids('[Content]')
      title_id = tokenizer.convert_tokens_to_ids('[Title]')
      unk_id = tokenizer.convert_tokens_to_ids('[UNK]')
      sep_id = tokenizer.convert_tokens_to_ids('[SEP]')
      # 將tokens索引化,變成模型所需格式
      content_tokens = ['[CLS]'] + content_tokens + ['[SEP]']
      input_ids = tokenizer.convert_tokens_to_ids(content_tokens)
      # 將input_ids和token_type_ids進(jìn)行擴(kuò)充,擴(kuò)充到需要預(yù)測(cè)標(biāo)題的個(gè)數(shù),即batch_size
      input_ids = [copy.deepcopy(input_ids) for _ in range(args.batch_size)]
      token_type_ids = [[content_id] * len(content_tokens) for _ in range(args.batch_size)]
      # 將input_ids和token_type_ids變成tensor
      input_tensors = torch.tensor(input_ids).long().to(device)
      token_type_tensors = torch.tensor(token_type_ids).long().to(device)
      next_token_type = torch.tensor([[title_id] for _ in range(args.batch_size)]).long().to(device)
      # 用于存放每一步解碼的結(jié)果
      generated = []
      # 用于存放,完成解碼序列的序號(hào)
      finish_set = set()
      with torch.no_grad():
      # 遍歷生成標(biāo)題最大長(zhǎng)度
      for _ in range(args.generate_max_len):
      outputs = model(input_ids=input_tensors, token_type_ids=token_type_tensors)
      # 獲取預(yù)測(cè)結(jié)果序列的最后一個(gè)標(biāo)記,next_token_logits size:[batch_size, vocab_size]
      next_token_logits = outputs[0][:, -1, :]
      # 對(duì)batch_size進(jìn)行遍歷,將詞表中出現(xiàn)在序列中的詞的概率進(jìn)行懲罰
      for index in range(args.batch_size):
      for token_id in set([token_ids[index] for token_ids in generated]):
      next_token_logits[index][token_id] /= args.repetition_penalty
      # 對(duì)batch_size進(jìn)行遍歷,將詞表中的UNK的值設(shè)為無(wú)窮小
      for next_token_logit in next_token_logits:
      next_token_logit[unk_id] = -float('Inf')
      # 使用top_k_top_p_filtering函數(shù),按照top_k和top_p的值,對(duì)預(yù)測(cè)結(jié)果進(jìn)行篩選
      filter_logits = top_k_top_p_filtering(next_token_logits, top_k=args.top_k, top_p=args.top_p)
      # 對(duì)filter_logits的每一行做一次取值,輸出結(jié)果是每一次取值時(shí)filter_logits對(duì)應(yīng)行的下標(biāo),即詞表位置(詞的id)
      # filter_logits中的越大的值,越容易被選中
      next_tokens = torch.multinomial(F.softmax(filter_logits, dim=-1), num_samples=1)
      # 判斷如果哪個(gè)序列的預(yù)測(cè)標(biāo)記為sep_id時(shí),則加入到finish_set
      for index, token_id in enumerate(next_tokens[:, 0]):
      if token_id == sep_id:
      finish_set.add(index)
      # 判斷,如果finish_set包含全部的序列序號(hào),則停止預(yù)測(cè);否則繼續(xù)預(yù)測(cè)
      finish_flag = True
      for index in range(args.batch_size):
      if index not in finish_set:
      finish_flag = False
      break
      if finish_flag:
      break
      # 將預(yù)測(cè)標(biāo)記添加到generated中
      generated.append([token.item() for token in next_tokens[:, 0]])
      # 將預(yù)測(cè)結(jié)果拼接到input_tensors和token_type_tensors上,繼續(xù)下一次預(yù)測(cè)
      input_tensors = torch.cat((input_tensors, next_tokens), dim=-1)
      token_type_tensors = torch.cat((token_type_tensors, next_token_type), dim=-1)
      # 用于存儲(chǔ)預(yù)測(cè)結(jié)果
      candidate_responses = []
      # 對(duì)batch_size進(jìn)行遍歷,并將token_id變成對(duì)應(yīng)漢字
      for index in range(args.batch_size):
      responses = []
      for token_index in range(len(generated)):
      # 判斷,當(dāng)出現(xiàn)sep_id時(shí),停止在該序列中添加token
      if generated[token_index][index] != sep_id:
      responses.append(generated[token_index][index])
      else:
      break
      # 將token_id序列變成漢字序列,去除'##',并將[Space]替換成空格
      candidate_responses.append(
      ''.join(tokenizer.convert_ids_to_tokens(responses)).replace('##', '').replace('[Space]', ' '))
      return candidate_responses

      詳細(xì)代碼見(jiàn)Github項(xiàng)目的generate_title.py文件。

      測(cè)試結(jié)果如下:

      從測(cè)試集中抽一篇
      content:
      今日,中國(guó)三條重要高鐵干線——蘭新高鐵、貴廣鐵路和南廣鐵路將開(kāi)通運(yùn)營(yíng)。其中蘭新高鐵是中國(guó)首條高原高鐵,全長(zhǎng)1776公里,最高票價(jià)658元。貴廣鐵路最貴車票320元,南廣鐵路最貴車票206.5元,這兩條線路大大縮短西南與各地的時(shí)空距離。出行更方便了!中國(guó)“高鐵版圖”再擴(kuò)容 三條重要高鐵今日開(kāi)通
      title:
      生成的第1個(gè)標(biāo)題為:中國(guó)“高鐵版圖”再擴(kuò)容 三條重要高鐵今日開(kāi)通
      生成的第2個(gè)標(biāo)題為:貴廣鐵路最高鐵版圖
      生成的第3個(gè)標(biāo)題為:出行更方便了!中國(guó)“高鐵版圖”再擴(kuò)容三條重要高鐵今日開(kāi)通

      模型上線

      通過(guò)Flask框架搭建了一個(gè)Web服務(wù),將新聞?wù)赡P瓦M(jìn)行工程化,可以通過(guò)頁(yè)面可視化地體驗(yàn)新聞?wù)尚Ч?/p>

      詳細(xì)代碼見(jiàn)Github項(xiàng)目的http_server.py文件。

      并且在我之前文章中,詳細(xì)介紹過(guò)如何使用Flask框架搭建Web服務(wù),見(jiàn):https://zhuanlan.zhihu.com/p/143678340

      https://zhuanlan.zhihu.com/p/148224626

      啟動(dòng)服務(wù)命令:
      python3 http_server.py

      python3 http_server.py --http_id '0.0.0.0' --port 5555
      本地測(cè)試直接使用'127.0.0.1:5555/news-title-generate',如果給他人訪問(wèn),只需將'127.0.0.1'替換成的電腦的IP地址即可。
      初始頁(yè)面如下圖所示:
      輸入新聞?wù)暮?,點(diǎn)擊“一鍵生成”,可以獲取到生成的新聞標(biāo)題,如下圖所示:

      后期工作

      可能會(huì)將清華新聞數(shù)據(jù)、搜狗新聞數(shù)據(jù)等新聞數(shù)據(jù)集進(jìn)行整理清洗,構(gòu)建一個(gè)較完善的新聞?wù)獢?shù)據(jù)集。

      可能會(huì)使用新聞數(shù)據(jù)訓(xùn)練一個(gè)小的GPT2預(yù)訓(xùn)練模型。
      可能會(huì)對(duì)已上傳的新聞標(biāo)題生成模型進(jìn)行更新,訓(xùn)練一個(gè)效果較好的模型。

      總結(jié)

      GPT2模型已經(jīng)非常成熟,也有很多很好的開(kāi)源項(xiàng)目。筆者本著開(kāi)源之心,將代碼進(jìn)行整理,增加詳細(xì)注釋,希望可以幫助大家更好地理解代碼。也歡迎大家留言討論。
      參考
      GPT2-Chinese:https://github.com/Morizeyao/GPT2-Chinese
      GPT2-chitchat:https://github.com/yangjianxin1/GPT2-chitchat
      CDial-GPT:https://github.com/thu-coai/CDial-GPT
      GPT2:https://github.com/ConnorJL/GPT2
      transformers:https://github.com/huggingface/transformers

        本站是提供個(gè)人知識(shí)管理的網(wǎng)絡(luò)存儲(chǔ)空間,所有內(nèi)容均由用戶發(fā)布,不代表本站觀點(diǎn)。請(qǐng)注意甄別內(nèi)容中的聯(lián)系方式、誘導(dǎo)購(gòu)買等信息,謹(jǐn)防詐騙。如發(fā)現(xiàn)有害或侵權(quán)內(nèi)容,請(qǐng)點(diǎn)擊一鍵舉報(bào)。
        轉(zhuǎn)藏 分享 獻(xiàn)花(0

        0條評(píng)論

        發(fā)表

        請(qǐng)遵守用戶 評(píng)論公約

        類似文章 更多