def collate_func(batch_data):
"""
DataLoader所需的collate_fun函数,将数据处理成tensor形式
Args:
batch_data: batch数据
Returns:
"""
batch_size = len(batch_data)
# 如果batch_size为0,则返回一个空字典
if batch_size == 0:
return {}
input_ids_list, token_type_ids_list = [], []
# 获取一个batch数据中的最大长度
max_len = max([len(instance["input_ids"]) for instance in batch_data])
for instance in batch_data:
# 按照batch中的最大数据长度,对数据进行padding填充
input_ids_temp = instance["input_ids"]
input_ids_temp.extend([0]*(max_len-len(instance["input_ids"])))
token_type_ids_temp = instance["token_type_ids"]
token_type_ids_temp.extend([0] * (max_len - len(instance["token_type_ids"])))
# 将list数据转换为tensor数据
# input_ids_list.append(torch.from_numpy(np.array(input_ids_temp, dtype=np.int32)).long())
# token_type_ids_list.append(torch.from_numpy(np.array(token_type_ids_temp, dtype=np.int32)).long())
input_ids_list.append(torch.tensor(input_ids_temp, dtype=torch.long))
token_type_ids_list.append(torch.tensor(token_type_ids_temp, dtype=torch.long))
return {"input_ids": pad_sequence(input_ids_list, batch_first=True, padding_value=0),
"token_type_ids": pad_sequence(token_type_ids_list, batch_first=True, padding_value=0)}
这个使用pad_sequence了,是不是就不需要在上面extend了?