窝牛号

Mamba线性时间序列框架超越Transformer模型?注意力机制如何翻身

最近一个火爆AI圈的一个线性时间序列模型框架Mamba引起了大家的讨论,由于其性能在相关的实验上超越了transformer模型,让人不禁讨论注意力机制是否是大模型唯一的出路。transformer模型与mamba模型的大战也是一触即发,很多小伙伴们已经开始尝试把项目转移到mamba框架上。且mamba的开源代码已经上传到GitHub上,小伙伴们可以运行代码尝试。

目前几乎所有的大语言模型都是基于transformer模型框架。许多二次时间序列架构(例如线性注意力、门控卷积和循环模型以及结构化状态空间模型(SSM))已被开发来解决 Transformers 在长序列上的计算效率低下问题,但它们在重要模态上的表现并不好。

Mamba 是一种新的状态空间模型架构,在语言建模等信息密集数据上显示出良好的性能,而以前的二次模型在 Transformers 方面存在不足。Mamba基于结构化状态空间模型的,并使用FlashAttention进行高效的硬件设计和实现。

Mamba 享有快速推理(吞吐量比 Transformer 高 5 倍)和序列长度线性缩放,并且其性能在高达百万长度序列的实际数据上得到了很多的提高。 作为通用序列模型骨干,Mamba 在语言、音频和基因组学等多种任务上实现了最先进的性能。 在语言建模方面, Mamba-3B 模型在预训练和下游任务方面都优于相同大小的 Transformer模型,并且与两倍大小的 Transformer 相匹配。

Mamba模型采用门控MLP单元,并与归一化残差连接组成新的Mamba基础模型,并使用此模型搭建整体的Mamba模型,整个模型框架并未采用transformer模型的注意力机制。

Mamba代码已经开源,且已经在hugging face上面发布了不同尺寸大小的模型,小伙伴可以使用预训练模型进行相关的测试以及自己模型的测试。

在使用本期代码前,首先需要保证自己有GPU显卡以及相关驱动,并安装了python以及pytorch相关库,并安装Mamba。

!pip install causal-conv1d==1.0.0 !pip install mamba-ssm==1.0.1 !export LC_ALL=&34; !export LD_LIBRARY_PATH=&34; !export LIBRARY_PATH=&34; !ldconfig /usr/lib64-nvidia

然后我们import相关的库,并加载Mamba预训练模型,代码运行时,会自动下载相关的模型。

import torch from transformers import AutoTokenizer from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel device = &34; tokenizer = AutoTokenizer.from_pretrained(&34;) tokenizer.eos_token = &34; tokenizer.pad_token = tokenizer.eos_token tokenizer.chat_template = AutoTokenizer.from_pretrained(&34;).chat_template model = MambaLMHeadModel.from_pretrained(&34;, device=&34;, dtype=torch.float16)

运行以上代码,模型会自动下载相关的模型,无需自己手动配置,比较方便使用。

tokenizer_config.json: 100%4.79k/4.79k [00:00<00:00, 246kB/s] tokenizer.json: 100%2.11M/2.11M [00:00<00:00, 6.42MB/s] special_tokens_map.json: 100%131/131 [00:00<00:00, 9.12kB/s] Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained. tokenizer_config.json: 100%1.43k/1.43k [00:00<00:00, 103kB/s] tokenizer.model: 100%493k/493k [00:00<00:00, 10.3MB/s] tokenizer.json: 100%1.80M/1.80M [00:00<00:00, 7.78MB/s] added_tokens.json: 100%42.0/42.0 [00:00<00:00, 2.30kB/s] special_tokens_map.json: 100%168/168 [00:00<00:00, 12.8kB/s] config.json: 100%201/201 [00:00<00:00, 11.0kB/s] pytorch_model.bin: 100%5.55G/5.55G [02:07<00:00, 80.5MB/s]

待模型下载完成后,使用MambaLMHeadModel.from_pretrained加载预训练模型,并使用预训练模型建立一个model函数。

messages = [] while True: user_message = input(&34;) messages.append(dict( role=&34;, content=user_message )) input_ids = tokenizer.apply_chat_template(messages, return_tensors=&34;, add_generation_prompt=True).to(&34;) out = model.generate(input_ids=input_ids, max_length=2000, temperature=0.9, top_p=0.7, eos_token_id=tokenizer.eos_token_id) decoded = tokenizer.batch_decode(out) messages.append(dict( role=&34;, content=decoded[0].split(&34;)[-1])) print(&34;, decoded[0].split(&34;)[-1])

模型建立完成后,我们就可以跟模型进行聊天了,由于预训练模型的数据限制,此模型并无法跟ChatGPT一样,什么问题都可以回答。首先第一个问题便是编程问题,其编码问题可以完美的解决。

Your message: do you know python Model: Yes, I do know about Python. Python is a popular programming language that was developed by Guido van Rossum in the early 1990s. It is a high-level, general-purpose programming language that is designed to be easy to read, write, and modify. Python is used in a wide range of applications, including web development, data analysis, scientific computing, and machine learning.<|endoftext|> Your message: Please use python to write a bubble sort Model: Sure, here&34;1+1=2&34;John&文章首发挑战赛#

https://arxiv.org/abs/2312.00752 https://github.com/state-spaces/mamba https://huggingface.co/state-spaces

本站所发布的文字与图片素材为非商业目的改编或整理,版权归原作者所有,如侵权或涉及违法,请联系我们删除

窝牛号 wwww.93ysy.com   沪ICP备2021036305号-1