Antkillerfarm Hacking V7.0

OpenXLA

2024-03-13

OpenXLA

In October 2022, Google created yet another project, OpenXLA, aiming to decouple XLA from TF.

Official site:

https://github.com/openxla/xla

XRT & PJRT

PJRT: Pretty much Just another RunTime

The TF codebase contains the following two folders:

tensorflow/compiler/xla/xrt

tensorflow/compiler/xla/pjrt

The role of XRT & PJRT is to give other frameworks, such as PyTorch/JAX, the ability to generate XLA IR and execute it.

PJRT is an upgraded version of XRT.
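From the framework side, the PJRT backend to use is selected via the PJRT_DEVICE environment variable. A minimal torch_xla sketch, assuming the bundled CPU backend:

import os
os.environ["PJRT_DEVICE"] = "CPU"  # must be set before torch_xla creates the client

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()              # a PJRT-backed XLA device
t = torch.randn(2, 2, device=device)
print(t.device)                       # e.g. xla:0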

Reference:

https://github.com/openxla/xla/blob/main/xla/pjrt/c/docs/pjrt_integration_guide.md

PyTorch XLA

The PyTorch team provides the following project to support XLA:

https://github.com/pytorch/xla

In practice, this project is maintained by Google.

At a rough glance, it is mostly upper-layer code; the lower layers call straight into the TF implementation. So if the target hardware is already hooked up to the TF XLA interface, in theory PyTorch can run on it without any modification.

Documentation:

https://pytorch.org/xla/master/

PyTorch on XLA Devices

Example:

https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist.py
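The core of that example is the usual torch_xla training pattern. A minimal sketch, with a stand-in model and random data in place of MNIST:

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = nn.Linear(28 * 28, 10).to(device)   # stand-in for the MNIST network
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

for step in range(10):
    data = torch.randn(8, 28 * 28, device=device)       # stand-in for a batch
    target = torch.randint(0, 10, (8,), device=device)
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    xm.optimizer_step(optimizer)  # gradient all-reduce (if distributed) + step()
    xm.mark_step()                # cut the lazy graph; triggers XLA compile/execute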

Build:

git clone --recursive https://github.com/pytorch/xla.git
cd xla

# modify bazel/dependencies.bzl to configure pytorch:
# PYTORCH_LOCAL_DIR = "../pytorch"

# modify WORKSPACE to configure openxla:
# by commenting out the http_archive above and uncommenting the following:
# local_repository(
#    name = "xla",
#    path = "/path/to/openxla",
# )

export XLA_CUDA=0
export BUNDLE_LIBTPU=0
python setup.py install

export PJRT_DEVICE=NPU
export NPU_LIBRARY_PATH=<work dir>/pytorch_xla/xla/build/temp.linux-x86_64-cpython-312/bazel-xla/bazel-out/k8-opt/bin/external/xla/xla/pjrt/plugin/libnpu.so

python3 ./torch_xla_lenet.py

Reference backend implementations:

Intel XPU:

https://github.com/intel/intel-extension-for-openxla

IREE:

https://github.com/openxla/openxla-pjrt-plugin


Mesh API:

https://pytorch.org/blog/pytorch-xla-spmd

PyTorch/XLA SPMD: Scale Up Model Training and Serving with Automatic Parallelization

https://www.openteams.com/large-scale-training-of-hugging-face-transformers-on-tpus-with-pytorch-xla-fsdp

Large Scale Training of Hugging Face Transformers on TPUs With PyTorch/XLA FSDP


Sometimes the gcc version used inside a conda environment differs from the system gcc, which can cause problems in the Python extension library produced by the build:

https://blog.csdn.net/weixin_39379635/article/details/129159713

How to solve the version `GLIBCXX_3.4.29' not found problem

Similarly, there are cases where a shared library cannot be found; these are generally fixed with LD_LIBRARY_PATH.


https://blog.csdn.net/qq_40329272/article/details/111801695

undefined symbol:_ZN5torchXXXX


Printing network information:

torch_xla/torch_xla/core/dynamo_bridge.py:

def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):

To print a tabular representation of the graph, use:

xla_model.graph.print_tabular()

To get a SVG visualization of the graph, use:

from torch.fx.passes.graph_drawer import FxGraphDrawer
drawer = FxGraphDrawer(xla_model, "model_name")
with open(f"svg/save/dir/{drawer._name}.svg", mode="wb") as f:
    f.write(drawer.get_dot_graph().create_svg())

Custom op:

XlaDeviceType hw_type = static_cast<XlaDeviceType>(GetCurrentDevice().type());
if (hw_type == XlaDeviceType::TPU)
{
  resized = xla::CustomCall(...);
}
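From Python, one way to verify that such a lowering actually emits a custom call is to dump the HLO of the pending graph. A sketch, assuming the backend lowers resize through xla::CustomCall as above (_get_xla_tensors_hlo is an internal debug helper):

import torch
import torch.nn.functional as F
import torch_xla
import torch_xla.core.xla_model as xm

x = torch.randn(1, 3, 224, 224, device=xm.xla_device())
y = F.interpolate(x, size=(112, 112))

# Dump the HLO for the pending computation and look for custom-call.
print(torch_xla._XLAC._get_xla_tensors_hlo([y]))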

Dynamically registering a plugin:

torch_xla/experimental/plugins.py

    from torch_xla.experimental import plugins as xp
    import torch_xla.runtime as xr

    npu_plugin = NpuPlugin()
    xp.use_dynamic_plugins()
    xp.register_plugin("npu", npu_plugin)
    xr.set_device_type("npu")

    device_type = xr.device_type()
    assert device_type == "npu"
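The NpuPlugin above is user-defined. A minimal sketch, assuming the libnpu.so path exported as NPU_LIBRARY_PATH in the build step (DevicePlugin is the base class in torch_xla.experimental.plugins):

import os
from torch_xla.experimental import plugins

class NpuPlugin(plugins.DevicePlugin):
    def library_path(self) -> str:
        # Reuse the PJRT C API plugin built earlier (libnpu.so).
        return os.environ["NPU_LIBRARY_PATH"]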

Switches related to narrowing 64-bit integer types (PyTorch long) to 32 bits:

XLA_USE_32BIT_LONG

squash_64bit_types
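A quick sketch of the effect (the variable must be set before torch_xla is imported; the exact narrowing behavior is version-dependent):

import os
os.environ["XLA_USE_32BIT_LONG"] = "1"  # store PyTorch long tensors as s32 on device

import torch
import torch_xla.core.xla_model as xm

t = torch.zeros(4, dtype=torch.long, device=xm.xla_device())
# On the device the data is s32; PyTorch still reports torch.int64.
print(t.dtype)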


The chain along which the options are passed:

// torch_xla
torch_xla::runtime::GetComputationClient()
client = std::make_unique<PjRtComputationClient>();
options.allowed_devices = allowed_devices;
client_ = std::move(xla::GetStreamExecutorGpuClient(options).value());

// openxla
GetGpuXlaClient(options.platform_name, options.allowed_devices)
LocalClientOptions options;
ClientLibrary::GetOrCreateLocalClient(options);
ServiceOptions service_options;
LocalService::NewService(service_options)
BackendOptions backend_options;
Backend::CreateBackend(backend_options)
PlatformUtil::GetStreamExecutors(platform, options.allowed_devices())

References:

https://pytorch.org/blog/pytorch-2.0-xla/

PyTorch 2.0 & XLA—The Latest Cutting Edge Features

https://huggingface.co/blog/pytorch-xla

Hugging Face on PyTorch / XLA TPUs: Faster and cheaper training

https://pytorch.org/tutorials/recipes/intel_extension_for_pytorch.html

Intel Extension for PyTorch

https://pytorch.org/blog/path-achieve-low-inference-latency/

The Path to Achieve Ultra-Low Inference Latency With LLaMA 65B on PyTorch/XLA

Quantization

PyTorch quantization mainly uses the PT2E (PyTorch 2 Export) interface.

Demo:

torch_xla/test/stablehlo/test_pt2e_qdq.py
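The overall PT2E flow in that demo looks roughly like the following sketch (API locations have shifted across PyTorch releases; this follows the 2.2-era tutorial, with XNNPACKQuantizer standing in for a backend-specific quantizer):

import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3)).eval()
example_inputs = (torch.randn(1, 3, 224, 224),)

# 1. Export the eager model to an FX graph.
exported = capture_pre_autograd_graph(model, example_inputs)

# 2. Insert observers according to the quantizer's annotations.
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
prepared = prepare_pt2e(exported, quantizer)

# 3. Calibrate with representative data.
prepared(*example_inputs)

# 4. Rewrite to explicit quantize/dequantize ops.
quantized = convert_pt2e(prepared, fold_quantize=False)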

convert_pt2e has a parameter named fold_quantize; at present, when it is True there are still some issues:

https://github.com/pytorch/xla/issues/6567

The pipeline is PyTorch IR -> XLA IR -> StableHLO IR, and the second step still has some problems.

XLA_STABLEHLO_COMPILE (an environment variable that routes torch_xla compilation through StableHLO):

  %custom-call.101 = f32[1,3,224,224]{3,2,1,0} custom-call(f32[1,3,224,224]{3,2,1,0} %p12.100), custom_call_target="stablehlo.uniform_quantize", api_version=API_VERSION_TYPED_FFI, backend_config={scale=[0.00391965],zero_point=[-128],storage_type=si8,expressed_type=f32,storage_min=-128,storage_max=127}
  %custom-call.102 = f32[1,3,224,224]{3,2,1,0} custom-call(f32[1,3,224,224]{3,2,1,0} %custom-call.101), custom_call_target="stablehlo.uniform_dequantize", api_version=API_VERSION_TYPED_FFI, backend_config={scale=[0.00391965],zero_point=[-128],storage_type=si8,expressed_type=f32,storage_min=-128,storage_max=127}
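To get the StableHLO for an exported model directly, torch_xla also provides a converter. A minimal sketch using torch_xla.stablehlo (module and method names as of the 2.x releases):

import torch
from torch.export import export
from torch_xla.stablehlo import exported_program_to_stablehlo

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3)).eval()
example_inputs = (torch.randn(1, 3, 224, 224),)

exported = export(model, example_inputs)
shlo = exported_program_to_stablehlo(exported)
print(shlo.get_stablehlo_text())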

References:

https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html

PyTorch 2 Export Post Training Quantization

https://pytorch.org/tutorials/prototype/pt2e_quantizer.html

How to Write a Quantizer for PyTorch 2 Export Quantization

Profile

At the C++ level:

Use tsl::profiler::TraceMe, the same as in TF XLA.

At the Python level:

Use the torch_xla.debug.profiler module.

Demos:

torch_xla/test/pjrt/test_profiler.py

torch_xla/test/test_profile_mp_mnist.py

These two examples demonstrate two different ways of using the tracer.
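The gist of both: start a profiler server in the training process, annotate interesting regions, then capture a trace (e.g. from TensorBoard). A minimal sketch:

import torch_xla.debug.profiler as xp

def train_step():
    ...  # forward/backward/optimizer code

server = xp.start_server(9012)   # serves profile data, e.g. for TensorBoard capture

with xp.Trace('train_step'):     # named region; shows up in the trace viewer
    train_step()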

FSDP for torch_xla

For the basic concepts of FSDP, see "Parallel & Framework & Optimization (IV)".

The distributed test cases, including those for FSDP, are mainly in:

  • torch_xla/test/test_train_mp_mnist.py
  • torch_xla/test/test_train_mp_mnist_zero1.py
  • torch_xla/test/test_train_mp_mnist_fsdp_with_ckpt.py

  • min_num_params: the minimum number of parameters for FSDP's default auto-wrapping.
  • cpu_offload: whether to offload parameters and gradients to the CPU.
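A minimal wrapping sketch based on those tests (the threshold is illustrative; in the real tests this runs inside xmp.spawn):

import functools
import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
from torch_xla.distributed.fsdp.wrap import size_based_auto_wrap_policy

model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 10))
model = FSDP(
    model.to(xm.xla_device()),
    auto_wrap_policy=functools.partial(size_based_auto_wrap_policy,
                                       min_num_params=int(1e6)),
)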

XlaFullyShardedDataParallel
self._shard_parameters_(params_to_shard)
torch_xla._XLAC._replace_xla_tensor(p, p.new_zeros(1))
shard_data = self._get_shard(p)
_flatten_and_pad_to_world_size
self._fsdp_wrapped_module: nn.Module = XlaFlattenParamsWrapper

References:

https://pytorch.org/blog/scaling-pytorch-models-on-cloud-tpus-with-fsdp/

Scaling PyTorch models on Cloud TPUs with FSDP

https://zhuanlan.zhihu.com/p/616858352

Some memory-saving techniques for LLMs


FSDP's computation graph creates and deletes weights on demand. Without special handling, the common subexpression elimination (CSE) pass in later optimization stages would optimize these operations away, so XLA introduced the OptimizationBarrier op to prevent this.
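torch_xla exposes the op at the Python level as an in-place helper; a small sketch (this is how the FSDP implementation pins tensors against CSE):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
w = torch.randn(4, 4, device=device)
g = torch.randn(4, 4, device=device)

# Wrap the producing ops in an OptimizationBarrier so CSE/DCE
# cannot merge or eliminate them.
xm.optimization_barrier_([w, g])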

JAX

A domain-specific tracing JIT compiler, built by a team at Google (not an official release), for generating high-performance accelerator code from pure Python and NumPy machine learning programs.

Code:

https://github.com/google/jax

Documentation:

https://jax.readthedocs.io/en/latest/

Under the hood, JAX is also based on XLA.
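A minimal illustration: jax.jit traces the Python function and hands XLA the lowered IR, which can be inspected directly:

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.tanh(x) * 2.0

x = jnp.arange(4.0)
print(f(x))                           # traced, compiled by XLA, then executed
print(jax.jit(f).lower(x).as_text())  # the IR handed to XLA (StableHLO text)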

JAX is not a replacement for TF; it lacks some data-preparation and scheduling functionality. These features are generally provided by haiku/flax.

RLax: a reinforcement learning library based on JAX.

References:

https://mp.weixin.qq.com/s/IMMdbF33ZHEz7N_XwgIhHA

Try this new tool from Google: it might be even better than TensorFlow!

https://mp.weixin.qq.com/s/tZ3yWQ9--l9e81UqoUoWIQ

A replacement for TensorFlow? Google open-sources the machine learning library JAX

https://mp.weixin.qq.com/s/eaYwiV2LZNRwzPEeOA1XFg

Rising star JAX: taking on both TensorFlow and PyTorch! Poised to become Google's main scientific computing and neural network library

https://mp.weixin.qq.com/s/NhMbr_niHjSaqh2azuSaog

Knowing only TF and PyTorch isn't enough: see how to move from PyTorch to the autodiff powerhouse JAX

https://wzzju.github.io/jax/xla/2022/02/17/jax-cpp/

Converting a JAX program to HLO for execution

https://zhuanlan.zhihu.com/p/532504225

A simple JAX tutorial for PyTorch users (1): Introduction to JAX

https://zhuanlan.zhihu.com/p/544216783

A simple JAX tutorial for PyTorch users (2): How to train a neural network
