在torch 2.1.0上配置SASA(2022 AAAI)

背景

SASA是2022年AAAI的一篇文章,用于Point-based 3D object detection。SASA所开放的源码基于OpenPCDet v0.3.0的版本,在torch 2.1.0安装这个版本的OpenPCDet,以及运行SASA的代码,都会出现一定的问题。这篇帖子记录了在torch 2.1.0上配置并运行SASA的过程。

安装conda环境

python版本是3.8,PyTorch版本2.1.0,对应CUDA版本用的是11.8。

conda create -n wfmamba python=3.8

conda activate wfmamba

conda install pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 pytorch-cuda=11.8 -c pytorch -c nvidia

clone github仓库

git clone https://github.com/blakechen97/SASA.git

安装spconv

pip install spconv-cu118

安装OpenPCDet

pcdet/ops目录替换为OpenPCDet v0.6.0(即最新版本)对应的目录,setup.py替换为OpenPCDet v0.6.0的版本。之后到pcdet/ops/pointnet2/pointnet2_batch/src目录下,将所有该目录下的文件修改为SASA的内容(因为SASA新增了内容),但是要注释掉:

#include <THC/THC.h>

以及:

extern THCState *state;

在新版本的torch中,这个头文件已经被弃用,使用这个头文件会有问题。

下载scikit-imagepip install scikit-image -i https://mirrors.aliyun.com/pypi/simple

运行setup.pypython setup.py develop,完成OpenPCDet的安装。

代码调整

直接run python train.py --cfg_file cfgs/kitti_models/3dssd_sasa.yamlpython train.py --cfg_file cfgs/kitti_models/pointrcnn_sasa.yaml,会报一个关于就地操作的错误(可能在这之前会有ROAD_PLANE的报错,把yaml文件的USE_ROAD_PLANE改成False即可)。原因是代码中某处使用了张量的就地操作,导致梯度的反向传播出现问题。修改文件pcdet/ops/pointnet2/pointnet2_batch/pointnet2_modules.py

line 43: new_features *= idx_cnt_mask 修改为 new_features = new_features * idx_cnt_mask

line 206: new_features *= idx_cnt_mask 修改为 new_features = new_features * idx_cnt_mask

之后就可以跑通代码了。