随着AI模型的参数量越来越大,对算力的需求也水涨船高。
比如最近,Llama-3.1登上了最强开源大模型的宝座,但超大杯405B版本的内存就高达900多GB,这对算力构成了更加苛刻的挑战。
如何降低算力的使用成本和使用门槛,已经成为许多公司寻求突破的关键。Felafax就是其中的一家创业公司,致力于简化AI训练集群的搭建流程。
NikhilSonti和NikhinSonti创立了Felafax,他们的口号是在构建开源AI平台,为下一代AI硬件服务,将机器学习的训练成本降低30%。
与英伟达相比,AMD的GPU,尤其是MI300X系列,提供了更高的性价比,按每美元计算,其性能表现更为出色。
最近,Felafax的联合创始人NikhilSonti发布了一篇博客,详细分享了如何通过8张AMDMI300XGPU和JAX微调LLaMA3.1405B模型的方法,所有代码现已开源。
Github链接:https://github.com/felafax/felafax
机器之心对博客内容进行了不改变原意的编译、整理,以下是博客内容:
JAX尤其适合非英伟达硬件
JAX是一个强大的机器学习库,结合了类似NumPy的API、自动微分功能以及Google的XLA编译器。它在模型并行化方面提供了优秀的API,因此非常适合像LLaMA3.1405B这样的超大模型训练。
在使用AMD硬件时,JAX有几个明显的优势:
多硬件并行支持:JAX采用XLA(加速线性代数)编译器,将计算编译为硬件无关的中间表示(HLO),这意味着同样的JAX代码无需修改便可高效运行在不同硬件后端,包括AMDGPU。独立于底层硬件:XLA编译器的优化策略是通用的,不针对某个特定的硬件平台。这使得任何支持XLA的硬件设备(如CPU、GPU、TPU)都能受益于这些优化,获得更好的性能表现。极高的适应性:从NVIDIA转移到AMD(或其他硬件)时,JAX只需做极少的代码改动。而相较之下,PyTorch与英伟达的CUDA生态系统紧密耦合,迁移过程相对复杂。
因此,JAX成为了我们在非英伟达硬件上的最佳选择。
拉取Docker镜像:
dockerpullrocm/jax:latest
启动Docker容器:
#PulltheDockerImage:
dockerpullrocm/jax:latest
#StarttheDockerContainer:
dockerrun-it-w/workspace--device=/dev/kfd--device=/dev/dri--group-addvideo
--cap-add=SYS_PTRACE--security-optseccomp=unconfined--shm-size16Grocm/jax:latest
#VerifytheInstallation:
python3-c/'importjax;print(jax.devices())/'
验证安装
python3-c/'importjax;print(jax.devices())/'
训练使用了一个配备了8张AMDMI300xGPU的AMD节点。每张MI300x拥有192GB的HBM3内存,性能表现与最新的英伟达H100GPU相比非常出色。
与英伟达H100的比较,来源:TensorWave
训练LLaMA405B:性能与可扩展性
使用JAX,可以成功地在AMDGPU上训练LLaMA405B模型。我们使用LoRA微调,将所有模型权重和LoRA参数都设为bfloat16,LoRArank设为8,LoRAalpha设为16:
模型大小:LLaMA模型的权重占用了约800GB的显存。LoRA权重+优化器状态:大约占用了400GB的显存。显存总使用量:占总显存的77%,约1200GB。限制:由于405B模型的规模过大,batch大小和序列长度的空间有限,使用的batchsize为16,序列长度为64。JIT编译:由于空间限制,无法运行JIT编译版本;它可能需要比急切模式稍多的空间。训练速度:使用JAX急切模式,约为35tokens/秒。内存效率:稳定在约70%左右。扩展性:在8张GPU上,使用JAX的扩展性接近线性。
由于硬件和显存的限制,我们无法运行JIT编译版本的405B模型,整个训练过程是在JAX的急切模式下执行的,因此还有很大的进步空间。
下图中显示了在一次微调训练步骤中,8张GPU的显存利用率和rocm-smi输出:
GPU利用率:
训练设置
将LLaMA3.1从PyTorch移植到JAX
此前,NikhilSonti分享过如何将LLaMA3.1从PyTorch移植到JAX。他指出,目前90%的大型语言模型(LLM)都运行在NVIDIAGPU上,但实际上还有一些同样强大且性价比更高的替代方案。例如,在GoogleTPU上训练和部署Llama3.1的成本比NVIDIAGPU低约30%。
然而,支持非NVIDIA硬件的开发工具较为匮乏。Sonti最初尝试使用PyTorchXLA在TPU上训练Llama3.1,但过程并不顺利。XLA与PyTorch的集成不够完善,缺少一些关键的库(如bitsandbytes无法正常运行),同时还遇到了一些难以解决的HuggingFace错误。
为此,他决定调整策略,将Llama3.1从PyTorch移植到JAX,成功解决了这些问题。Sonti还录制了详细的教程视频,并开源了所有代码:
方法演示:https://dub.sh/felafax-demo代码仓库:https://github.com/felafax/felafax
加载模型,并把模型参数分片
处理像LLaMA405B这样的超大模型,需要在多个设备之间高效地进行参数分片。以下是如何通过JAX实现这一点的。
在JAX中进行参数分片
为了将巨大的LLaMA405B模型高效地分布到8张AMDGPU上,需要使用JAX的设备网格(devicemesh)功能。
部署代码:https://github.com/felafax/felafax/blob/e2a96a0e207e1dc70effde099fe33a9e42a7d5cb/llama3_jax/trainer_engine/jax_utils.py#L69
JAX的设备网格可以帮助我们把可用的设备组织成一个网格,让我们可以指定如何把模型的参数和计算分配到不同的GPU上。
在本文的设置中,需要创建一个形状为(1,8,1)的网格,并将轴分别命名为数据并行(dp)、全分片数据并行(fsdp)和模型并行(mp)。然后,为模型的每个张量定义特定的分片规则,指定这些维度如何沿着这些网格轴进行分片。
DEVICES=jax.devices()
DEVICE_COUNT=len(DEVICES)
DEVICE_MESH=mesh_utils.create_device_mesh((1,8,1))
MESH=Mesh(devices=DEVICE_MESH,axis_names=(/"dp/",/"fsdp/",/"mp/"))
可视化分片
可以使用以下代码来可视化分片结果,从而方便地验证分片规则是否按预期应用。
jax.debug.visualize_array_sharding
分片规则
模型不同组件的分片规则如下所示:
参数如何分片:
参数要在8个GPU之间分配。例如,LMhead(lm_head/kernel)张量有两个轴,按照PS(/"fsdp/",/"mp/")进行分片。在本例中是8和1,因此可以看到该张量在第一个轴上沿着8个GPU被拆分。
Non-Replicated参数:
没有任何分片规范的参数会在所有设备上进行复制。例如,层归一化(attention_norm/kernel和ffn_norm/kernel)没有设置分片规范,是PS(None)。
应用分片函数
在加载模型时,使用以下分片函数逐步对模型权重进行分片:
defmake_shard_and_gather_fns(partition_specs):
defmake_shard_fn(partition_spec):
out_sharding=NamedSharding(mesh,partition_spec)
defshard_fn(tensor):
returnjax.device_put(tensor,out_sharding).block_until_ready()
returnshard_fn
shard_fns=jax.tree_util.tree_map(make_shard_fn,partition_specs)
returnshard_fns
#Createshardfunctionsbasedonpartitioningrules
shard_fns=make_shard_and_gather_fns(partitioning_rules)
这使得我们能够将每个参数放置在指定的设备上,并按照设定的分片进行处理。
分片训练Batch
最初,训练Batch是正常创建的,但在输入模型之前,需要按照下面的代码在GPU上进行分片:
train_batch=jax.device_put(train_batch,
NamedSharding(self.mesh,PS(/"dp/",/"fsdp/")))
在这里,我们指定训练Batch应该在/"dp/"和/"fsdp/"轴上进行分片,在本例中分别对应于被分成1和8份,如果把结果可视化出来,如下所示:
分片前:
在调用jax.device_put之后:
加入LoRA
LoRA通过将权重更新分解为低秩矩阵,减少了可训练参数的数量,这对于微调大型模型特别有效。以下是在AMDGPU上微调Llama3.1-405的LoRA的要点:
将LoRA参数(lora_a和lora_b)与主模型参数分开。使用jax.lax.stop_gradient(kernel)来防止对主模型权重的更新。使用lax.dot_general进行快速、精确控制的矩阵运算。LoRA输出在添加到主输出之前会被缩放为(self.lora_alpha/self.lora_rank)。
LoRADense层
在此设定一个自定义的LoRADense层,该层集成了LoRA参数:
classLoRADense(nn.Module):
features:int
lora_rank:int=8
lora_alpha:float=16.0
@nn.compact
def__call__(self,inputs:Any)->Any:
#Originalkernelparameter(frozen)
kernel=self.param(/'kernel/',...)
y=lax.dot_general(inputs,jax.lax.stop_gradient(kernel),...)
#LoRAparameters(trainable)
lora_a=self.variable(/'lora_params/',/'lora_a/',...,...)
lora_b=self.variable(/'lora_params/',/'lora_b/',...,...)
#ComputeLoRAoutput
lora_output=lax.dot_general(inputs,lora_a.value,...)
lora_output=lax.dot_general(lora_output,lora_b.value,...)
#CombineoriginaloutputwithLoRAmodifications
y+=(self.lora_alpha/self.lora_rank)*lora_output
returny.astype(self.dtype)
分片LoRA参数
为了高效地在设备之间分配LoRA参数,我们也通过JAX设定了分片规则,这确保了LoRA参数与主模型参数的分片一致,优化了内存使用和计算效率。
LoRAAmatrices(lora_a)
LoRAA矩阵(lora_a)
分片规则:PS(/"fsdp/",/"mp/")可视化结果:如下图所示,lora_a参数被分片为(8,1),这意味着第一个轴在8个设备上进行分片(/"fsdp/"轴),而第二个轴未进行分片。
LoRAB矩阵(lora_b)
分片规则:PS(/"mp/",/"fsdp/")可视化结果:如下图所示,lora_b参数被分片为(1,8),这意味着第二个轴在8个设备上进行分片(fsdp轴),而第一个轴未进行分片。
这种分片策略优化了参数的分配,减少了通信开销,并在训练过程中增强了并行性。它确保每个设备仅持有一部分LoRA参数,使得大模型如LLaMA405B的高效扩展成为可能。
仅更新LoRA参数
为了优化训练,在微调LLaMA405B模型,只计算LoRA参数的梯度,保持主模型参数不变。这个方法减少了内存使用,并加速了训练,因为只更新较少的参数。可以移步GitHub仓库,查看实现细节。
在训练过程中,每一步都涉及将一批输入数据通过模型进行处理。由于只有LoRA参数是可训练的,因此模型的预测和计算的损失仅依赖于这些参数,然后对LoRA参数进行反向传播。只更新这些参数简化了训练过程,使得在多个GPU上高效微调像LLaMA405B这样的大型模型成为可能。
更多研究细节,请参考原博客。
未经允许不得转载:头条资讯网_今日热点_娱乐才是你关心的时事 » 微调大模型,AMD MI300X就够了!跟着这篇博客微调Llama 3.1 405B