背景
SASA
是2022年AAAI的一篇文章,用于Point-based 3D object detection。SASA
所开放的源码基于OpenPCDet v0.3.0
的版本,在torch 2.1.0
安装这个版本的OpenPCDet
,以及运行SASA
的代码,都会出现一定的问题。这篇帖子记录了在torch 2.1.0上
配置并运行SASA
的过程。
安装conda环境
python版本是3.8,PyTorch版本2.1.0,对应CUDA版本用的是11.8。
conda create -n wfmamba python=3.8
conda activate wfmamba
conda install pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 pytorch-cuda=11.8 -c pytorch -c nvidia
clone github仓库
git clone https://github.com/blakechen97/SASA.git
安装spconv
pip install spconv-cu118
安装OpenPCDet
将pcdet/ops
目录替换为OpenPCDet v0.6.0
(即最新版本)对应的目录,setup.py
替换为OpenPCDet v0.6.0
的版本。之后到pcdet/ops/pointnet2/pointnet2_batch/src
目录下,将所有该目录下的文件修改为SASA
的内容(因为SASA新增了内容),但是要注释掉:
#include <THC/THC.h>
以及:
extern THCState *state;
在新版本的torch
中,这个头文件已经被弃用,使用这个头文件会有问题。
下载scikit-image
:pip install scikit-image -i https://mirrors.aliyun.com/pypi/simple
运行setup.py
:python setup.py develop
,完成OpenPCDet
的安装。
代码调整
直接run python train.py --cfg_file cfgs/kitti_models/3dssd_sasa.yaml
或 python train.py --cfg_file cfgs/kitti_models/pointrcnn_sasa.yaml
,会报一个关于就地操作的错误(可能在这之前会有ROAD_PLANE
的报错,把yaml
文件的USE_ROAD_PLANE
改成False
即可)。原因是代码中某处使用了张量的就地操作,导致梯度的反向传播出现问题。修改文件pcdet/ops/pointnet2/pointnet2_batch/pointnet2_modules.py
:
line 43
: new_features *= idx_cnt_mask
修改为 new_features = new_features * idx_cnt_mask
;
line 206
: new_features *= idx_cnt_mask
修改为 new_features = new_features * idx_cnt_mask
;
之后就可以跑通代码了。