TorchScript. C++でpytorchを動かす
C++でpytorchを動かした際のメモ。
[pytorch]
- tracingとscriptingについて[2]。traceは、example dataを入力しそのときに実行された計算のみを記録する。ifやloopがある場合は、scriptを用いる。
- torch.nn.ModuleListにアクセスする際は、indexを指定できない。"for module in modules"で呼び出す。
https://github.com/pytorch/pytorch/issues/16123
上のURLに加え、modulelistを__constants__として登録しておく必要がある。
https://pytorch.org/docs/stable/jit.html#for-loops-over-constant-nn-modulelist
ネットワークが複雑な場合、上のURL必読。
- methodをコンパイルしたい場合には@torch.jit.export、したくない場合には@torch.jit.ignoreでデコレート。
- sub tensorにアクセスする際、torch.splitを用いてtuple of tensorsに変換しておく。逆の場合は、"var: List[torch.Tensor] = []"に対してappend & stack。
- TorchScriptの全体の流れ (複数入力に対する対処)
inputs_nt = namedtuple('input_nt', ['arg1', 'arg2'])
inputs = inputs_nt(arg1, arg2)
traced_module = torch.jit.trace_module(module, {"forward": tuple(inputs)})
traced_module.save("/path/to/save/traced/module");
https://github.com/pytorch/pytorch/issues/16453
[C++]
- tensor::Tensorの値は、coutで簡単に確認できる。
- (tensor::Tensor dataの)中のデータがポインタやvectorで欲しい場合は
float* ptr = data.data_ptr<float>();
CV系は門外漢(https://discuss.pytorch.org/t/libtorch-c-convert-a-tensor-to-cv-mat-single-channel/47701/2)。
- std::vector(v)からtensor::Tensor(t)へ
tensor::Tensor t = torch::from_blob(v.data(), {b /* batch size */, t /* seq len */}, at::kFloat).clone();
https://pytorch.org/cppdocs/notes/tensor_creation.html
https://discuss.pytorch.org/t/can-i-initialize-tensor-from-std-vector-in-libtorch/33236
- pytorch側への入力が複数tensorの場合、以下のようにする。
std::vector<torch::jit::IValue> inputs;
inputs.push_back(input_tensor1);
inputs.push_back(input_tensor2);
torch::Tensor output = module->forward(inputs).toTensor();
https://github.com/pytorch/pytorch/issues/18337
- pytorch側が複数のtensorを返す場合、それぞれのtensorには以下のようにアクセス。
auto outputs = module->forward(inputs).toTuple();
torch::Tensor out1 = outputs->elements()[0].toTensor();
torch::Tensor out2 = outputs->elements()[1].toTensor();
https://discuss.pytorch.org/t/difference-between-torch-tensor-and-at-tensor/35806
https://github.com/pytorch/pytorch/issues/13638
[その他]
- 保存したファイルをunzipすると、変換後のコードなど中身を確認できる。
[公式]
document of TorchScript
[1] https://pytorch.org/docs/stable/jit.html
TorchScriptのチュートリアル。pytorch codeと対応するtraced code。
[2] https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html
pytorch codeをC++で実行する際の、全体の大雑把な流れ。
[3] https://brsoff.github.io/tutorials/advanced/cpp_export.html
[4] https://pytorch.org/cppdocs/
Custom C++ and CUDA Extensions
[5] https://pytorch.org/tutorials/advanced/cpp_extension.html