如何评估预测值与真实值之间的匹配质量
toqiye 2024-10-18 13:57 2205 浏览 0 评论
1 代码实现
来自opencompass/opencompass/datasets/teval/evaluators/planning_evaluator.py
这段代码是一个比较复杂的过程,用于评估预测计划与真实计划之间的匹配质量。它包括多个关键步骤:计算相似度分数、创建匹配图、执行匹配算法、生成匹配映射,以及使用最长递增子序列(LIS)算法来计算正确匹配的节点数。
预测值如下
“predictions”: “[{\n “id”: 0,\n “name”: “AirbnbSearch.search_property_by_place”,\n “args”: “{‘place’: ‘柏林’}”\n}, {\n “id”: 1,\n “name”: “AirbnbSearch.get_property_details”,\n “args”: “{‘propertyId’: 1}”\n}, {\n “id”: 2,\n “name”: “AirbnbSearch.get_property_details”,\n “args”: “{‘propertyId’: 2}”\n}, {\n “id”: 3,\n “name”: “AirbnbSearch.get_property_details”,\n “args”: “{‘propertyId’: 3}”\n}, {\n “id”: 4,\n “name”: “ArxivSearch.get_arxiv_article_information”,\n “args”: “{‘query’: ‘太阳能’}”\n}]”,
真实值如下
“references”: [
{
“name”: “AirbnbSearch.search_property_by_place”,
“id”: 0,
“args”: “{‘place’: ‘Berlin’}”
},
{
“name”: “AirbnbSearch.get_property_reviews”,
“id”: 1,
“args”: “{‘propertyId’: 36339325}”
},
{
“name”: “AirbnbSearch.get_property_details”,
“id”: 2,
“args”: “{‘propertyId’: 45475252}”
},
{
“name”: “AirbnbSearch.get_property_checkout_price”,
“id”: 3,
“args”: “{‘propertyId’: 47215807, ‘checkIn’: ‘2023-10-20’, ‘checkOut’: ‘2023-10-23’}”
},
{
“name”: “ArxivSearch.get_arxiv_article_information”,
“id”: 4,
“args”: “{‘query’: ‘solar energy’}”
},
{
“name”: “ArxivSearch.get_arxiv_article_information”,
“id”: 5,
“args”: “{‘query’: ‘solar energy’}”
},
{
“name”: “FinishAction”,
“id”: 6,
“args”: “{}”
}
],
最终匹配结果如下
“precision”: 0.8
def bertscore_match(self, pred_plan, gt_plan) -> dict:
"""
Calculate the similarity between predicted plan and golden answer,
A plan can be regarded a sequence of actions, and each action has a name and args.
Firstly, use bertscore to calculate pointwise similarity by:
similarity(u, v) = bertscore(u.name, v.name) * name_weight + bertscore(u.args, v.args) * args_weight;
Secondly, use Hungarian matching to match the points;
Finally, use LIS to calculate the number of matched nodes.
"""
if len(pred_plan) == 0 or len(gt_plan) == 0:
return {
'precision': 0,
'recall': 0,
'f1_score': 0
}
pred_plan = copy.deepcopy(sorted(pred_plan, key=lambda x: x['id']))
gt_plan = copy.deepcopy(sorted(gt_plan, key=lambda x: x['id']))
#Add end action
#Currently it is hard-code
if pred_plan[-1]['name'] == 'FinishAction':
pred_plan = pred_plan[:-1]
if gt_plan[-1]['name'] == 'FinishAction':
gt_plan = gt_plan[:-1]
#The total counts of nodes and edges.
len_pred = len(pred_plan)
len_gt = len(gt_plan)
bert_score_matrix = np.zeros((len_pred, len_gt))
name_pred, args_pred = [], []
name_gt, args_gt = [], []
for i in range(len_pred):
name_pred.append(pred_plan[i]['name'])
args_pred.append(str(pred_plan[i]['args']))
for i in range(len_gt):
name_gt.append(gt_plan[i]['name'])
args_gt.append(str(gt_plan[i]['args']))
name_pred_emb = self.sentence_model.encode(name_pred, convert_to_tensor=True)
name_gt_emb = self.sentence_model.encode(name_gt, convert_to_tensor=True)
args_pred_emb = self.sentence_model.encode(args_pred, convert_to_tensor=True)
args_gt_emb = self.sentence_model.encode(args_gt, convert_to_tensor=True)
name_cosine_scores = np.maximum(util.cos_sim(name_pred_emb, name_gt_emb).cpu().numpy(), 0)
args_cosine_scores = np.maximum(util.cos_sim(args_pred_emb, args_gt_emb).cpu().numpy(), 0)
for i in range(len_pred):
for j in range(len_gt):
bert_score_matrix[i][j] = \
name_cosine_scores[i][j] * self.name_weight \
+ args_cosine_scores[i][j] * self.args_weight
G = nx.Graph()
for i in range(len_pred):
for j in range(len_gt):
if bert_score_matrix[i][j] > self.match_threshold:
G.add_edge(i, str(j), weight=bert_score_matrix[i][j])
max_weight_matching = nx.max_weight_matching(G)
pred_to_gt_mapping = dict()
for key in max_weight_matching:
if type(key[0]) == int:
pred_to_gt_mapping[int(key[0])] = int(key[1])
else:
pred_to_gt_mapping[int(key[1])] = int(key[0])
#If a prediction node does not match any golden answer node, we mark the node as -1.
for i in range(len_pred):
if i not in pred_to_gt_mapping:
pred_to_gt_mapping[i] = -1
#Calculate how many nodes are matched by Longest Increasing Subsequence (LIS)
dp = np.ones(len_pred)
for i in range(len_pred):
for j in range(i):
if pred_to_gt_mapping[i] == -1 or pred_to_gt_mapping[j] == -1:
continue
if pred_to_gt_mapping[i] > pred_to_gt_mapping[j]:
dp[i] = max(dp[i], dp[j] + 1)
correct_count = int(max(dp))
precision = correct_count / len(pred_plan)
代码中关键变量的取值示例
gt_plan [{‘name’: ‘AirbnbSearch.search_property_by_place’, ‘id’: 0, ‘args’: “{‘place’: ‘Berlin’}”}, {‘name’: ‘AirbnbSearch.get_property_reviews’, ‘id’: 1, ‘args’: “{‘propertyId’: 36339325}”}, {‘name’: ‘AirbnbSearch.get_property_details’, ‘id’: 2, ‘args’: “{‘propertyId’: 45475252}”}, {‘name’: ‘AirbnbSearch.get_property_checkout_price’, ‘id’: 3, ‘args’: “{‘propertyId’: 47215807, ‘checkIn’: ‘2023-10-20’, ‘checkOut’: ‘2023-10-23’}”}, {‘name’: ‘ArxivSearch.get_arxiv_article_information’, ‘id’: 4, ‘args’: “{‘query’: ‘solar energy’}”}, {‘name’: ‘ArxivSearch.get_arxiv_article_information’, ‘id’: 5, ‘args’: “{‘query’: ‘solar energy’}”}]
pred_plan [{‘name’: ‘AirbnbSearch.search_property_by_place’, ‘id’: 0, ‘args’: “{‘place’: ‘柏林’}”}, {‘name’: ‘AirbnbSearch.get_property_details’, ‘id’: 1, ‘args’: “{‘propertyId’: 1}”}, {‘name’: ‘AirbnbSearch.get_property_details’, ‘id’: 2, ‘args’: “{‘propertyId’: 2}”}, {‘name’: ‘AirbnbSearch.get_property_details’, ‘id’: 3, ‘args’: “{‘propertyId’: 3}”}, {‘name’: ‘ArxivSearch.get_arxiv_article_information’, ‘id’: 4, ‘args’: “{‘query’: ‘太阳能’}”}]
args_gt 和 args_pred 分别是真实参数列表和预测参数列表。
name_gt 和 name_pred 分别是真实动作名称列表和预测动作名称列表。
args_gt [“{‘place’: ‘Berlin’}”, “{‘propertyId’: 36339325}”, “{‘propertyId’: 45475252}”, “{‘propertyId’: 47215807, ‘checkIn’: ‘2023-10-20’, ‘checkOut’: ‘2023-10-23’}”, “{‘query’: ‘solar energy’}”, “{‘query’: ‘solar energy’}”]
args_pred [“{‘place’: ‘柏林’}”, “{‘propertyId’: 1}”, “{‘propertyId’: 2}”, “{‘propertyId’: 3}”, “{‘query’: ‘太阳能’}”]
name_gt [‘AirbnbSearch.search_property_by_place’, ‘AirbnbSearch.get_property_reviews’, ‘AirbnbSearch.get_property_details’, ‘AirbnbSearch.get_property_checkout_price’, ‘ArxivSearch.get_arxiv_article_information’, ‘ArxivSearch.get_arxiv_article_information’]
name_pred [‘AirbnbSearch.search_property_by_place’, ‘AirbnbSearch.get_property_details’, ‘AirbnbSearch.get_property_details’, ‘AirbnbSearch.get_property_details’, ‘ArxivSearch.get_arxiv_article_information’]
max_weight_matching {(‘5’, 4), (‘1’, 3), (2, ‘3’), (‘2’, 1), (0, ‘0’)}
pred_to_gt_mapping {4: 5, 3: 1, 2: 3, 1: 2, 0: 0}
dp array([1., 2., 3., 2., 4.])
precision 0.8
2 计算 BertScore 矩阵
- 双层循环:通过两层循环,遍历预测列表(len_pred)和真实列表(len_gt)的每个元素。这是为了比较每个预测元素与每个真实元素的相似度。
- 相似度计算:使用余弦相似度分数(name_cosine_scores 和 args_cosine_scores),结合设定的权重(name_weight 和 args_weight),计算出预测和真实元素间的综合相似度。这个分数存储在 bert_score_matrix 矩阵中。
3 创建图并执行匹配
- 图的创建:使用网络X库创建一个图 G。图中的每个节点代表一个预测或真实元素。
- 添加边:根据 bert_score_matrix 中的相似度分数,如果分数超过预设的阈值(match_threshold),则在对应的预测节点和真实节点之间添加一条边。边的权重是相似度分数。
- 执行匹配:使用网络X的 max_weight_matching 方法找出图中的最大权重匹配,这个方法会尝试找出一种匹配方式,使得所有匹配的权重总和最大。
4 生成匹配映射
- 创建映射表:通过解析匹配结果,创建一个映射表 pred_to_gt_mapping,记录每个预测节点所匹配到的真实节点索引。
- 未匹配节点标记:遍历所有预测节点,如果某个节点没有在匹配结果中,将其映射值标记为 -1,表示该预测节点没有找到对应的匹配。
5 计算正确匹配的节点数
- 初始化动态规划数组:dp 数组用于存储每个节点为终点的最长递增子序列的长度。
- 计算 LIS:双层循环遍历 pred_to_gt_mapping,如果两个节点都有有效的匹配,并且它们的映射值是递增的,则更新 dp 数组。
- 计算正确匹配的总数:最后,通过取 dp 数组中的最大值得到正确匹配的节点总数。
这段代码实现的是一个动态规划算法,用于计算最长递增子序列(LIS)的长度。在这个上下文中,它被用来确定预测到真实匹配项之间的递增顺序的长度,从而评估预测序列的质量。
最长递增子序列(LIS)
最长递增子序列是一个在数列中找出一组递增排序的最长序列的问题。在这个应用中,pred_to_gt_mapping 映射了预测节点到真实节点的索引,而 LIS 用于找出这些映射中的最长有序(递增)关系。
初始化
dp 数组用于存储每个元素作为序列结束点时的最长递增子序列的长度。数组初始化为1,表示每个元素自身至少可以构成长度为1的序列。
动态规划过程
- 外层循环 (for i in range(len_pred)):遍历每一个预测节点。
- 内层循环 (for j in range(i)):对于每个节点 i,遍历所有在其前面的节点 j。
- 条件判断:if pred_to_gt_mapping[i] == -1 or pred_to_gt_mapping[j] == -1: 如果任一节点未匹配任何真实节点(标记为 -1),则跳过当前循环迭代,因为未匹配的节点不能用来形成有效的递增序列。if pred_to_gt_mapping[i] > pred_to_gt_mapping[j]: 如果节点 i 的匹配索引大于节点 j 的匹配索引,说明这两个节点可以形成一个递增的序列。此时,需要更新 dp[i] 的值。
更新dp[i]
- dp[i] = max(dp[i], dp[j] + 1): 这一步是核心的动态规划更新逻辑。如果节点 i 和节点 j 可以形成递增序列,那么以 i 结尾的最长递增子序列长度可以从 j 的序列长度加1得到。更新 dp[i] 为其自身的值和 dp[j] + 1 之间的最大值。
计算最长递增子序列的总长度
- correct_count = int(max(dp)): 通过取 dp 数组中的最大值,我们得到整个预测序列中最长的递增子序列的长度,这代表了正确匹配的最大顺序数量。
示例说明
给定的映射 {4: 5, 3: 1, 2: 3, 1: 2, 0: 0} 中:
- 开始时每个点自身至少构成长度为1的序列。
- 当检查到节点 3 和节点 1 时,因为 1 的映射是 2,3 的映射是 1,且 1 < 2,不满足递增关系,因此不更新 dp[3]。
- 检查到节点 4 和之前的节点时,会发现它与之前的所有有效映射节点(0, 1, 2, 3)都形成了递增序列,所以它的 dp 值将基于最长的递增序列更新。
这样,dp 数组最终记录了以每个节点结尾的最长递增子序列的长度,其最大值代表了整个序列中正确匹配的最大顺序数。
6 通俗解释
这个过程可以通过一个比喻来通俗地解释:想象一下你在组织一个舞会,其中的每个舞者(动作)有他们特定的舞步(动作名称)和舞伴(参数)。你的任务是为每个舞者找到最佳的舞伴,以便他们的舞步能够尽可能完美地配合。
生成 BertScore 矩阵:
步骤1: 将每个舞者的舞步和他们选择的舞伴看作是一个计划中的一个动作。你需要记录下每个舞者的名字和他们喜欢的舞伴的类型。
步骤2: 使用一种特殊的“评分系统”(self.sentence_model.encode),这个系统可以评估每位舞者的舞步和舞伴选择与其他舞者的相似度。这就像给他们的舞步和舞伴选择打分,分数越高说明他们越可能搭配得好。
步骤3: 计算舞步和舞伴之间的配合度(余弦相似度),这可以帮助你看到哪些舞者可能是理想的舞伴。
步骤4: 根据舞步和舞伴的重要性(name_weight 和 args_weight),合成一个最终得分(BertScore)。这就像决定在评估一个舞者的时候,他们的舞步技巧和选择的舞伴哪个更重要。
匈牙利匹配算法(Hungarian Matching):
步骤5: 使用生成的得分(BertScore 矩阵)来创建一个舞会的配对图。在这个图中,每个节点代表一个舞者,每条边代表两个舞者可能成为舞伴的得分。
步骤6: 使用一个算法(匈牙利匹配算法,通过网络X的 max_weight_matching 方法实现)来找出这场舞会中所有可能的最佳舞伴组合。这保证了每个舞者都找到了最匹配的舞伴,而且整个舞会的舞伴匹配得分最高。
计算匹配和最长递增子序列 (LIS):
步骤7: 为了记录哪些舞者找到了舞伴,你创建一个映射表 (pred_to_gt_mapping),记录每个舞者匹配的结果。
步骤8: 如果某个舞者找不到任何合适的舞伴,你就把他们标记为 -1。
步骤9: 最后,使用最长递增子序列(LIS)的方法来计算实际上有多少舞者与他们的舞伴舞步完美匹配。这个数值能帮你了解舞会的整体配对成功率。
这整个过程就像是组织一场完美的舞会,每个舞者都能找到与之舞步完美匹配的舞伴,使得整个舞会的舞步和谐统一。
- 上一篇:一面算法总结(经典算法面试题)
- 下一篇:数码管使用方法详解(数码管连接方法)
相关推荐
- 国产web端开源ui组件-通用前端ui界面组件库
-
写个序吧:随着工作的不断深入,越来越发现很多好的前端开源项目都来自于国外,国产的开源项目很多时候面临叫好不叫座甚至有很多就消失不见了。开源和创新,不仅仅是需要我们的参与同样也需要我们不断地传播,因此才...
- 微信小程序商城项目,包括,分销,团购,秒杀,优惠券惠源码分享
-
源码获取,关注转发文章之后私信回复【源码】即可免费获取到!小程序商城,包括:分销(支持三级)、团购(拼多多模式)、秒杀、优惠券、等功能技术框架核心框架:SpringFramework4安全框架:A...
- VUE3前端开发入门系列教程
-
一直以来使用ThinkJS开发,使用Semantic-UI手写代码,又缺少一些table等插件,好累。平时使用NodeJS开发后端较多,一直有接触VUE想法,总是不得入门(可能是思维固化了),再次深入...
- 支持分销、团购、秒杀、优惠券、微信商城项目,源码免费分享
-
小程序商城,包括:分销(支持三级)、团购(拼多多模式)、秒杀、优惠券、等功能如果您需要用到这个微信小程序的源码,欢迎关注转发之后私信【源码】来免费获取到!面向对象Open-Shop是企业在创立初期很好...
- 在Gitee获8.5k Star,做微信小程序商城看这一个开源项目就够了
-
商城系统是小程序中比较热门的类型,许多开发者在寻找商城类小程序项目时,都会遇到一些声称「开源」但是并不是完全开源,有时候还会收费的项目。今天Gitee介绍的这款微信小程序商城项目就是一款从前端到后...
- 七款国内免费开源PHP CMS推荐,无限制,可商用
-
自织梦cms收费后,很多使用dedecms的站长都转移到了别的cms系统上,上一期也给大家分享了几款国外开源cms系统,今天来给大家分享几款国内的免费且可商用的phpcms系统:PbootCmsPb...
- VUE3前端开发入门系列教程二:使用iView框架辅助开发
-
1、安装iView新框架,支持VUE3npminstallview-ui-plus2、编辑src/main.js,添加以下内容,导入js和css到项目importViewUIPlusfrom...
- TS 真正比 JS 强大的那些特性
-
在前端开发领域,JavaScript(JS)一直是当之无愧的武林盟主,凭借灵活多变的特性和超广泛的兼容性打下大片江山。然而,随着前端应用日益复杂,TypeScript(TS)这位后起之秀崛起,以独特优...
- 自写一个函数将js对象转为Ts的Interface接口
-
如今的前端开发typescript已经成为一项必不可以少的技能了,但是频繁的定义Interface接口会给我带来许多工作量,我想了想如何来减少这些非必要且费时的工作量呢,于是决定写一个函数,将对象放进...
- 如何优雅地校验后端接口数据,不做前端背锅侠
-
背景最近新接手了一批项目,还没来得及接新需求,一大堆bug就接踵而至,仔细一看,应该返回数组的字段返回了null,或者没有返回,甚至返回了字符串"null"???这我能忍?我立刻截...
- 正点原子I.MX6U嵌入式Linux C应用编 第十八章 输入设备应用编程
-
输入设备应用编程本章学习...
- Python时间序列分析:使用TSFresh进行自动化特征提取
-
TSFresh(基于可扩展假设检验的时间序列特征提取)是一个专门用于时间序列数据特征自动提取的框架。该框架提取的特征可直接应用于分类、回归和异常检测等机器学习任务。TSFresh通过自动化特征工程流程...
- 人教版八下数学第十九章《一次函数》辅导(6)一次函数(1)
-
人教版八下数学第十九章《一次函数》辅导(6)一次函数(1)一、生活中的一次函数探究1(1)一个小球由静止开始沿一个斜坡向下滚动,其速度每秒增加2m/s,若小球的速度为vm/s,运动时间为ts,求v关于...
- 笔记|Simulink中S函数的设计
-
S函数的简介S函数是Simulink中提供给用户的一个自定义模块,由于在研究过程中经常需要复杂的算法设计,Simulink中提供的模块无法满足使用,就需要用编程的形式设计出S函数模块,然后嵌入到系统中...
- 初探HarmonyOS开发,ArkTS语言初看
-
最近在研究HarmonyOS(鸿蒙操作系统)开发,HarmonyOS(鸿蒙操作系统)想必大家都不陌生了,但是我也是在该系统发布许久后才首次尝试上手开发,因为有写java后台的经验,也算是能堪堪上手。据...
你 发表评论:
欢迎- 一周热门
-
-
如何评估预测值与真实值之间的匹配质量
-
很香的几款开源免费的流程设计器
-
如何解决npm安装依赖报错ERESOLVE unable to resolve dependency tree
-
启动Autodesk软件时弹出FLEXnet License Finder对话框,这么办!
-
超详细的cmder工具介绍及功能、快捷键说明
-
硬核!Rust异步编程方式重大升级:新版Tokio如何提升10倍性能详解
-
基于 Vue 网页版IM聊天组件Lemon-IMUI
-
Vue组件库设计 | Vue3组件在线交互解释器
-
vue3 学习笔记(九)——script setup 语法糖用了才知道有多爽
-
「推荐」vue-code-view:一个在线编辑、实时预览的代码交互组件
-
- 最近发表
- 标签列表
-
- systemproperties (65)
- npm版本管理 (61)
- localhost:15672 (59)
- materialtheme (86)
- node-ssh (68)
- 图床搭建 (62)
- vue3addeventlistener (60)
- mybatisselectone (78)
- css圆形进度条 (69)
- androidble蓝牙开发 (62)
- android-gif-drawable (60)
- appender-ref (64)
- springbootmockito (68)
- gsonfastjson (59)
- 依赖注入的方式 (62)
- cookie跨域共享 (63)
- easyexcel导出图片 (77)
- dp数组 (61)
- js获取兄弟节点 (68)
- sysctl-a (60)
- window.target (62)
- apimodel注解的作用 (60)
- window.onerror (66)
- java链表listnode (59)
- springmvc教程 (65)