In October 2022, Google created another new project, OpenXLA, aiming to decouple XLA from TF.
Official site:
https://github.com/openxla/xla
PJRT: Pretty much Just another RunTime
The TF codebase contains the following two folders:
tensorflow/compiler/xla/xrt
tensorflow/compiler/xla/pjrt
XRT & PJRT serve to give other frameworks such as PyTorch/JAX the ability to generate XLA IR and execute it.
PJRT is the upgraded successor to XRT.
Reference:
https://github.com/openxla/xla/blob/main/xla/pjrt/c/docs/pjrt_integration_guide.md
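Per the integration guide, a PJRT plugin is simply a shared library that implements the PJRT C API and exports a GetPjrtApi() entry point. A hedged sketch for sanity-checking a plugin .so from Python (the path is hypothetical):

import ctypes

lib = ctypes.CDLL("/path/to/libnpu_pjrt.so")  # hypothetical plugin path
# A valid PJRT plugin exports GetPjrtApi(), which returns the PJRT_Api* table.
assert hasattr(lib, "GetPjrtApi")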
PyTorch officially provides the following project for XLA support:
https://github.com/pytorch/xla
This project is in fact maintained by Google.
At a glance, it is mostly upper-layer code; the lower layers call directly into the TF implementation. So if the target hardware is already wired into the TF XLA interface, in theory PyTorch can run on it without modification.
Documentation:
https://pytorch.org/xla/master/
PyTorch on XLA Devices
Example:
https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist.py
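A minimal sketch of running ops on an XLA device; xm.xla_device() and xm.mark_step() are the documented torch_xla APIs, and nothing backend-specific is assumed:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()   # resolves to whatever PJRT_DEVICE selects
x = torch.randn(2, 2, device=device)
y = x @ x                  # recorded into the lazy trace
xm.mark_step()             # cut the trace; compile and execute it via XLA
print(y.cpu())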
Build:
git clone --recursive https://github.com/pytorch/xla.git
cd xla
# modify bazel/dependencies.bzl to config pytorch:
# PYTORCH_LOCAL_DIR = "../pytorch"
# modify WORKSPACE to config openxla:
# by commenting out the http_archive above and uncommenting the following:
# local_repository(
# name = "xla",
# path = "/path/to/openxla",
# )
export XLA_CUDA=0
export BUNDLE_LIBTPU=0
python setup.py install
Run:
export PJRT_DEVICE=NPU
export NPU_LIBRARY_PATH=<work dir>/pytorch_xla/xla/build/temp.linux-x86_64-cpython-312/bazel-xla/bazel-out/k8-opt/bin/external/xla/xla/pjrt/plugin/libnpu.so
python3 ./torch_xla_lenet.py
Reference backend implementations:
Intel XPU:
https://github.com/intel/intel-extension-for-openxla
IREE:
https://github.com/openxla/openxla-pjrt-plugin
Mesh API:
https://pytorch.org/blog/pytorch-xla-spmd
PyTorch/XLA SPMD: Scale Up Model Training and Serving with Automatic Parallelization
https://www.openteams.com/large-scale-training-of-hugging-face-transformers-on-tpus-with-pytorch-xla-fsdp
Large Scale Training of Hugging Face Transformers on TPUs With PyTorch/XLA FSDP
Sometimes the gcc version used inside a conda environment differs from the local gcc, which can cause problems in the compiled Python extension:
https://blog.csdn.net/weixin_39379635/article/details/129159713
How to fix the version `GLIBCXX_3.4.29' not found error
Similarly, there are library-not-found problems, generally solved with LD_LIBRARY_PATH:
https://blog.csdn.net/qq_40329272/article/details/111801695
undefined symbol: _ZN5torchXXXX
Printing graph information:
torch_xla/torch_xla/core/dynamo_bridge.py:
def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
To print a tabular representation of the graph:
xla_model.graph.print_tabular()
To get an SVG visualization of the graph:
from torch.fx.passes.graph_drawer import FxGraphDrawer
drawer = FxGraphDrawer(xla_model, "model_name")
with open(f"svg/save/dir/{drawer._name}.svg", mode="wb") as f:
    f.write(drawer.get_dot_graph().create_svg())
custom op:
XlaDeviceType hw_type = static_cast<XlaDeviceType>(GetCurrentDevice().type());
if (hw_type == XlaDeviceType::TPU) {
  resized = xla::CustomCall(...);
}
Dynamically registering a plugin:
torch_xla/experimental/plugins.py
# imports assumed by the snippet below
import torch_xla.experimental.plugins as xp
import torch_xla.runtime as xr

npu_plugin = NpuPlugin()
xp.use_dynamic_plugins()
xp.register_plugin("npu", npu_plugin)
xr.set_device_type("npu")
device_type = xr.device_type()
assert device_type == "npu"
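The NpuPlugin above would be a DevicePlugin subclass whose library_path() points at the backend's PJRT plugin .so. A hedged sketch modeled on the CpuPlugin example in torch_xla's plugin docs (the .so location is hypothetical):

import os
import torch_xla.experimental.plugins as plugins

class NpuPlugin(plugins.DevicePlugin):
    def library_path(self) -> str:
        # hypothetical path to the PJRT plugin built earlier
        return os.path.join(os.path.dirname(__file__), "lib", "libnpu.so")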
64-bit type handling: the XLA_USE_32BIT_LONG env var (see also squash_64bit_types) maps PyTorch Long to a 32-bit XLA type, for backends without 64-bit integer support.
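A minimal sketch of the env var's effect, per the torch_xla troubleshooting docs (it must be set before torch_xla initializes):

import os
os.environ["XLA_USE_32BIT_LONG"] = "1"  # store torch.long as a 32-bit XLA type

import torch
import torch_xla.core.xla_model as xm

t = torch.zeros(3, dtype=torch.long, device=xm.xla_device())
# t still reports torch.int64 to PyTorch, but is held as 32 bits on the device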
The chain through which options are passed:
// torch_xla
torch_xla::runtime::GetComputationClient()
  client = std::make_unique<PjRtComputationClient>();
    options.allowed_devices = allowed_devices;
    client_ = std::move(xla::GetStreamExecutorGpuClient(options).value());
// openxla
GetGpuXlaClient(options.platform_name, options.allowed_devices)
  LocalClientOptions options;
  ClientLibrary::GetOrCreateLocalClient(options);
    ServiceOptions service_options;
    LocalService::NewService(service_options)
      BackendOptions backend_options;
      Backend::CreateBackend(backend_options)
        PlatformUtil::GetStreamExecutors(platform, options.allowed_devices())
References:
https://pytorch.org/blog/pytorch-2.0-xla/
PyTorch 2.0 & XLA—The Latest Cutting Edge Features
https://huggingface.co/blog/pytorch-xla
Hugging Face on PyTorch / XLA TPUs: Faster and cheaper training
https://pytorch.org/tutorials/recipes/intel_extension_for_pytorch.html
Intel Extension for PyTorch
https://pytorch.org/blog/path-achieve-low-inference-latency/
The Path to Achieve Ultra-Low Inference Latency With LLaMA 65B on PyTorch/XLA
PyTorch quantization mainly uses the PT2E (PyTorch 2 Export) interface.
demo:
torch_xla/test/stablehlo/test_pt2e_qdq.py
convert_pt2e has a parameter named fold_quantize; currently, with fold_quantize=True, there are still some issues:
https://github.com/pytorch/xla/issues/6567
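A hedged sketch of the PT2E flow from the official tutorials (the capture API has been renamed across PyTorch versions; MyModel is a hypothetical nn.Module):

import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer, get_symmetric_quantization_config)

model = MyModel().eval()                       # hypothetical float model
example_inputs = (torch.randn(1, 3, 224, 224),)
exported = capture_pre_autograd_graph(model, example_inputs)

quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)                      # calibration pass
quantized = convert_pt2e(prepared, fold_quantize=False)  # sidestep the issue above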
The pipeline is pytorch IR -> XLA IR -> stablehlo IR, and the second step still has some issues.
XLA_STABLEHLO_COMPILE
%custom-call.101 = f32[1,3,224,224]{3,2,1,0} custom-call(f32[1,3,224,224]{3,2,1,0} %p12.100), custom_call_target="stablehlo.uniform_quantize", api_version=API_VERSION_TYPED_FFI, backend_config={scale=[0.00391965],zero_point=[-128],storage_type=si8,expressed_type=f32,storage_min=-128,storage_max=127}
%custom-call.102 = f32[1,3,224,224]{3,2,1,0} custom-call(f32[1,3,224,224]{3,2,1,0} %custom-call.101), custom_call_target="stablehlo.uniform_dequantize", api_version=API_VERSION_TYPED_FFI, backend_config={scale=[0.00391965],zero_point=[-128],storage_type=si8,expressed_type=f32,storage_min=-128,storage_max=127}
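The dump above is what the q/dq ops look like after lowering to StableHLO, roughly as done in test_pt2e_qdq.py. A hedged sketch (exported_program_to_stablehlo and get_stablehlo_text follow torch_xla.stablehlo as used in that test; quantized/example_inputs come from the sketch above):

import torch
from torch_xla.stablehlo import exported_program_to_stablehlo

ep = torch.export.export(quantized, example_inputs)
shlo = exported_program_to_stablehlo(ep)
print(shlo.get_stablehlo_text('forward'))  # uniform_quantize/dequantize custom calls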
References:
https://pytorch.org/tutorials/prototype/pt2e_quant_ptq.html
PyTorch 2 Export Post Training Quantization
https://pytorch.org/tutorials/prototype/pt2e_quantizer.html
How to Write a Quantizer for PyTorch 2 Export Quantization
At the C++ level: use tsl::profiler::TraceMe, consistent with TF XLA.
At the Python level: use the torch_xla.debug.profiler module.
demo:
torch_xla/test/pjrt/test_profiler.py
torch_xla/test/test_profile_mp_mnist.py
The two examples demonstrate two different tracer usages.
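A minimal sketch combining both usages, per the torch_xla.debug.profiler API (the port number is arbitrary):

import torch_xla.debug.profiler as xp

server = xp.start_server(9012)  # serve profile data so TensorBoard can capture it

def train_loop_fn(loader):
    for step, (data, target) in enumerate(loader):
        with xp.StepTrace('train_step', step_num=step):  # one region per step
            ...  # forward/backward/optimizer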
For the basic concepts of FSDP, see 《并行 & 框架 & 优化(四)》 (Parallelism & Frameworks & Optimization, Part 4).
Distributed test cases, including those for FSDP, mainly live in:
min_num_params: the minimum number of parameters for FSDP's default auto-wrapping.
cpu_offload: whether to offload parameters and gradients to the CPU.
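A hedged usage sketch along the lines of the pytorch/xla FSDP README (MyModel and data are hypothetical; min_num_params feeds the size-based auto-wrap policy), shown before the internal call chain below:

import functools
import torch.optim as optim
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
from torch_xla.distributed.fsdp.wrap import size_based_auto_wrap_policy

device = xm.xla_device()
model = FSDP(MyModel().to(device),             # hypothetical nn.Module
             auto_wrap_policy=functools.partial(size_based_auto_wrap_policy,
                                                min_num_params=1_000_000))
optimizer = optim.SGD(model.parameters(), lr=1e-3)

loss = model(data).sum()                       # data: a batch already on `device`
loss.backward()
optimizer.step()                               # plain step(); grads are already reduced
xm.mark_step()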
Internal sharding call chain:
XlaFullyShardedDataParallel
  self._shard_parameters_(params_to_shard)
    torch_xla._XLAC._replace_xla_tensor(p, p.new_zeros(1))
    shard_data = self._get_shard(p)
      _flatten_and_pad_to_world_size
  self._fsdp_wrapped_module: nn.Module = XlaFlattenParamsWrapper
References:
https://pytorch.org/blog/scaling-pytorch-models-on-cloud-tpus-with-fsdp/
Scaling PyTorch models on Cloud TPUs with FSDP
https://zhuanlan.zhihu.com/p/616858352
Some memory-saving techniques for LLMs
FSDP's computation graph creates and frees weights on demand. Without special handling, the common subexpression elimination (CSE) pass in later optimization stages would optimize these operations away, so XLA introduced the OptimizationBarrier op to block this optimization.
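The barrier is exposed to Python as xm.optimization_barrier_; a minimal sketch of pinning tensors so that CSE cannot merge or remove the surrounding ops (the tensors here just stand in for FSDP's gathered parameters):

import torch
import torch_xla.core.xla_model as xm

w = torch.randn(4, device=xm.xla_device())
full_w = w * 2                         # stand-in for an all-gathered full parameter
xm.optimization_barrier_([w, full_w])  # in-place: shields these values from CSE/DCE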
JAX: a domain-specific tracing JIT compiler built by a Google team (not an officially released product) that generates high-performance accelerator code from pure Python and NumPy machine learning programs.
Code:
https://github.com/google/jax
Documentation:
https://jax.readthedocs.io/en/latest/
JAX is also built on top of XLA.
JAX is not a replacement for TF: it lacks some data preparation and orchestration functionality, which is generally provided by libraries such as Haiku/Flax.
RLax: a reinforcement learning library built on JAX.
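A minimal JAX sketch of the trace-and-JIT model described above (pure NumPy-style code compiled through XLA):

import jax
import jax.numpy as jnp

@jax.jit                       # trace once, compile to XLA, cache per shape/dtype
def predict(w, x):
    return jnp.tanh(x @ w)

grad_fn = jax.jit(jax.grad(lambda w, x: predict(w, x).sum()))
w = jnp.ones((3, 3))
x = jnp.ones((2, 3))
print(grad_fn(w, x))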
References:
https://mp.weixin.qq.com/s/IMMdbF33ZHEz7N_XwgIhHA
Try Google's new tool: it might be even better than TensorFlow!
https://mp.weixin.qq.com/s/tZ3yWQ9–l9e81UqoUoWIQ
A TensorFlow replacement? Google open-sources the machine learning library JAX
https://mp.weixin.qq.com/s/eaYwiV2LZNRwzPEeOA1XFg
Rising star JAX: challenging both TensorFlow and PyTorch, and poised to become Google's main scientific computing and neural network library
https://mp.weixin.qq.com/s/NhMbr_niHjSaqh2azuSaog
Knowing only TF and PyTorch isn't enough: how to move from PyTorch to the autodiff powerhouse JAX
https://wzzju.github.io/jax/xla/2022/02/17/jax-cpp/
Converting JAX programs to HLO for execution
https://zhuanlan.zhihu.com/p/532504225
A simple JAX tutorial for PyTorch users (1): Introduction to JAX
https://zhuanlan.zhihu.com/p/544216783
A simple JAX tutorial for PyTorch users (2): How to train a neural network