探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(四)分组多查询注意力

探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(四)分组多查询注意力

Grouped-query Attention,简称GQA

分组查询注意力(Grouped-query Attention,简称GQA)是多查询和多头注意力的插值。它在保持与多查询注意力相当的处理速度的同时,实现了与多头注意力相似的质量。

在这里插入图片描述

自回归解码的标准做法是缓存序列中先前标记的键和值,以加快注意力计算的速度。

  • 然而,随着上下文窗口或批量大小的增加,多头注意力(Multi-Head Attention,简称MHA)模型中键值缓存(Key-Value Cache,简称KV Cache)的大小所关联的内存成本显著增加。

  • 多查询注意力(Multi-Query Attention,简称MQA)是一种机制,它对多个查询仅使用单个键值头,这可以节省内存并大幅加快解码器的推理速度。

  • Llama(一种模型)整合了GQA,以解决在Transformer模型自回归解码期间的内存带宽挑战。主要问题源于GPU进行计算的速度比它们将数据移入内存的速度快。在每个阶段都需要加载解码器权重和注意力键,这消耗了大量的内存。

在这里插入图片描述
在这里插入图片描述

class SelfAttention(nn.Module): 
    def  __init__(self, args: ModelArgs):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        # Indicates the number of heads for the queries
        self.n_heads_q = args.n_heads
        # Indiates how many times the heads of keys and value should be repeated to match the head of the Query
        self.n_rep = self.n_heads_q // self.n_kv_heads
        # Indicates the dimentiona of each head
        self.head_dim = args.dim // args.n_heads

        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

        self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim))
        self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim))

    def forward(self, x: torch.Tensor, start_pos: int, freq_complex: torch.Tensor):
        batch_size, seq_len, _ = x.shape #(B, 1, dim)
        # Apply the wq, wk, wv matrices to query, key and value
        # (B, 1, dim) -> (B, 1, H_q * head_dim)
        xq = self.wq(x)
        # (B, 1, dim) -> (B, 1, H_kv * head_dim)
        xk = self.wk(x)
        xv = self.wv(x)

        # (B, 1, H_q * head_dim) -> (B, 1, H_q, head_dim)
        xq = xq.view(batch_size, seq_len, self.n_heads_q, self.head_dim)
        xk = xk.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)
        # (B, 1, H_kv * head_dim) -> (B, 1, H_kv, head_dim)
        xv = xv.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)

        # Apply the rotary embeddings to the keys and values
        # Does not chnage the shape of the tensor
        # (B, 1, H_kv, head_dim) -> (B, 1, H_kv, head_dim)
        xq = apply_rotary_embeddings(xq, freq_complex, device=x.device)
        xk = apply_rotary_embeddings(xk, freq_complex, device=x.device)

        # Replace the enty in the cache for this token
        self.cache_k[:batch_size, start_pos:start_pos + seq_len] = xk
        self.cache_v[:batch_size, start_pos:start_pos + seq_len] = xv

        # Retrive all the cached keys and values so far
        # (B, seq_len_kv, H_kv, head_dim)
        keys = self.cache_k[:batch_size, 0:start_pos + seq_len]
        values = self.cache_v[:batch_size, 0:start_pos+seq_len] 

        # Repeat the heads of the K and V to reach the number of heads of the queries
        keys = repeat_kv(keys, self.n_rep)
        values = repeat_kv(values, self.n_rep)

        # (B, 1, h_q, head_dim) --> (b, h_q, 1, head_dim)
        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # (B, h_q, 1, head_dim) @ (B, h_kv, seq_len-kv, head_dim) -> (B, h_q, 1, seq_len-kv)
        scores = torch.matmul(xq, keys.transpose(2,3)) / math.sqrt(self.head_dim)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)

        # (B, h_q, 1, seq_len) @ (B, h_q, seq_len-kv, head_dim) --> (b, h-q, q, head_dim)
        output = torch.matmul(scores, values)

        # (B, h_q, 1, head_dim) -> (B, 1, h_q, head_dim) -> ()
        output = (output.transpose(1,2).contiguous().view(batch_size, seq_len, -1))
        return self.wo(output) # (B, 1, dim) -> (B, 1, dim)

系列博客

探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(一)
https://duanzhihua.blog.csdn.net/article/details/138208650
探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(二)
https://duanzhihua.blog.csdn.net/article/details/138212328

探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(三)KV缓存
https://duanzhihua.blog.csdn.net/article/details/138213306
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/577091.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【35分钟掌握金融风控策略10】风控策略部署2

目录 策略部署 决策引擎系统简介 基于决策引擎进行策略部署 策略部署结果验证 知识点补充 测试验证 回溯比对 策略部署 策略主要部署在决策引擎上进行风险决策,接下来分别介绍决策引擎系统,以及基于决策引擎进行策略部署的相关内容。 决策引擎系…

java-Spring-(MyBatis框架-xml管理)

目录 前置条件 xml与注解比较 1.1 xml定义 1.2 和SQL注解比较 建包准备 插入数据 ​编辑 更新数据 删除数据 查询数据 查看单字段查询 🏷💣前置条件 创建一个spring boot 初始化的项目 🏷💣xml与注解比较 1.1 xml定义 …

微信小程序简单实现购物车功能

微信小程序简单实现购物车结算和购物车列表展示功能 实现在微信小程序中对每一个购物车界面的商品订单,进行勾选结算和取消结算的功能,相关界面截图如下: 具体实现示例代码为: 1、js代码: Page({/*** 页面的初始数…

SpringCloudAlibaba:2.1nacos

概述 概述 简介 Nacos是阿里巴巴开源的服务注册中心以及配置中心 Nacos注册中心Eureka 服务配置Config 服务总线Bus 官网 Nacos官网 | Nacos 官方社区 | Nacos 下载 | Nacos 名字由来 Naming:名字 Configurations:配置 Service:服务 功能…

【基础篇】Git 基础命令与核心概念

✅作者简介:大家好,我是小杨 📃个人主页:「小杨」的csdn博客 🐳希望大家多多支持🥰一起进步呀! 一,Git 初识 1.1,问题引入 不知道你工作或学习时,有没有遇到…

Centos8操作系统安装mysql5.7版本以及报错解决

目录 一、卸载MySql 1.首先查看已安装的mysql 2.逐个或者执行一下命令统一卸载掉 注意: 3. 卸载其他相关文件 二、安装MySql 1.安装mysql的rpm源 2.安装MySql 如果遇到以下错误: 问题一: 解决方法: 问题二、 解决方法&#xff1…

国产麒麟v10系统下打包electron+vue程序,报错unknown output format set

报错如下: 报错第一时间想到可能是代码配置原因报错,查看代码似乎感觉没啥问题 又查看具体报错原因可能是因为icon的原因报错,后面查阅发现ico在各系统平台会不兼容,也就是ico是给win下使用的,此处改下图标格式就ok&am…

1、Qt简介

文章目录 前言一、pySide2 / pySide6 ,PyQt5 / PyQt6二、安装包1 安装pyside22 安装pyqt5三、从一个简单的例子开始三、界面动作处理---信号(signal)与槽(slot)(Qt最核心的机制)--- 绑定事件封装到类中总结前言 参考文章:Qt简介 本文开始就开始进入到qt的开发笔记书写…

【论文解读】QUEST: Query Stream for Practical Cooperative Perception

QUEST 摘要引言QUERY COOPERATION PARADIGMQUEST FRAMEWORKA. Overall ArchitectureB. Cross-agent Query Interaction 实验结论 摘要 合作感知通过提供额外的视点和扩展感知领域,可以有效地提高个体感知性能。现有的合作模式要么是可解释的(结果合作),…

计算机视觉——OpenCV 使用分水岭算法进行图像分割

分水岭算法 分水岭算法:模拟地理形态的图像分割 分水岭算法通过模拟自然地形来实现图像中物体的分类。在这一过程中,每个像素的灰度值被视作其高度,灰度值较高的像素形成山脊,即分水岭,而二值化阈值则相当于水平面&am…

LabVIEW高效目标跟踪系统

LabVIEW高效目标跟踪系统 随着机器视觉技术的飞速发展,设计和实现高效的目标跟踪系统成为了众多领域关注的焦点。基于LabVIEW平台,结合NI Vision机器视觉库,开发了一种既高效又灵活的目标跟踪系统。通过面向对象编程方法和队列消息处理器程序…

以更多架构核心专利,推进 SDS 产业创新创造

今天是第 24 个世界知识产权日,今年世界知识产权日活动的主题是:“知识产权和可持续发展目标:立足创新创造,构建共同未来。” 这也正是 XSKY 在软件定义存储领域的目标之一。以“数据常青”为使命的 XSKY,始终立足于软…

济宁市中考报名照片要求及手机拍照采集证件照方法

随着中考报名季的到来,并且进入了中考报名演练阶段,济宁市的广大考生和家长都开始忙碌起来。报名过程中,上传一张符合要求的证件照是必不可少的环节。本文将详细介绍济宁市中考报名照片的具体要求,并提供一些实用的手机拍照采集证…

LeetCode in Python 74/240. Search a 2D Matrix I/II (搜索二维矩阵I/II)

搜索二维矩阵I其实可以转换为搜索一维数组,原因在于,只要先确定搜索的整数应该在哪一行,即可对该行进行二分查找。 搜索二维矩阵II中矩阵元素排列方式与I不同,但思想大致相同。 目录 LeetCode in Python 74. LeetCode in Pyth…

html表格导出为word文档,导出的部分表格内无法填写文字

导出技术实现:fileSaver.jshtml-docx-js 1.npm安装 npm install --save html-docx-js npm install --save file-saver 2.页面引入 import htmlDocx from html-docx-js/dist/html-docx; import saveAs from file-saver;components: {htmlDocx,saverFile, }, 3.页…

(MSFT.O)微软2024财年Q3营收619亿美元

在科技的浩渺宇宙中,一颗璀璨星辰再度闪耀其光芒——(MSFT.O)微软公司于2024财政年度第三季展现出惊人的财务表现,实现总营业收入达到令人咋舌的6190亿美元。这一辉煌成就不仅突显了微软作为全球技术领导者之一的地位,更引发了业界内外对这家…

Vue从0-1学会如何自定义封装v-指令

文章目录 介绍使用1. 理解指令2. 创建自定义指令3. 注册指令4. 使用自定义指令5. 自定义指令的钩子函数6. 传递参数和修饰符7. 总结 介绍 自定义封装 v-指令是 Vue.js 中非常强大的功能之一,它可以让我们扩展 Vue.js 的模板语法,为 HTML 元素添加自定义行…

在Elasticsearch 7.9.2中安装IK分词器并进行自定义词典配置

Elasticsearch是一个强大的开源搜索引擎,而IK分词器是针对中文文本分析的重要插件。本文将引导您完成在Elasticsearch 7.9.2版本中安装IK分词器、配置自定义词典以及验证分词效果的全过程。 步骤一:下载IK分词器 访问IK分词器的GitHub发布页面&#xf…

【网络编程】TCP流套接字编程 | Socket类 | ServerSocket类 | 文件资源泄露 | TCP回显服务器 | 网络编程

文章目录 TCP流套接字编程1.ServerSocket类2.Socket类3.文件资源泄露4.**TCP回显服务器** TCP流套接字编程 ​ ServerSocket类和Socket类这两个类都是用来表示socket文件(抽象了网卡这样的硬件设备)。 TCP是面向字节流的,传输的基本单位是b…

MySQL B+索引的工作原理及应用

引言 在数据库系统中,索引是优化查询、提高性能的关键技术之一。特别是在MySQL数据库中,B树索引作为最常用的索引类型,对数据库性能有着至关重要的影响。本文旨简单解析MySQL中B树索引的工作原理,帮助学生朋友们更好地理解和利用…