假设我们有如下的代码

#include <torch/torch.h>

int main() {
    auto param = torch::nn::Linear(2,2)->bias;
    auto input_param = torch::nn::Linear(2,2)->bias; 
    param = input_param;
}

此时我们会常规的认为这样的赋值会直接修改param的对象,然而事实是并不会。查看发现torch::Tensor下有一个copy_函数

于是修改为

#include <torch/torch.h>

int main() {
    auto param = torch::nn::Linear(2,2)->bias;
    auto input_param = torch::nn::Linear(2,2)->bias; 
    param.copy_(input_param);
}

还是有报错

what():  a leaf Variable that requires grad has been used in an in-place operation. (check_inplace at ../../torch/csrc/autograd/VariableTypeUtils.h:46)

含义大概是说自动求导的Tensor不允许直接修改(可能会破坏自动求导?

需要修改为

int main() {
    auto param = torch::nn::Linear(2,2)->bias;
    auto input_param = torch::nn::Linear(2,2)->bias; 
    param.data().copy_(input_param.clone());
}

这里加上的clone()可以不需要(为了安全点就加上了

需要注意的是对tensor的赋值必须使用data()获取内部数据才行

说点什么
支持Markdown语法
好耶,沙发还空着ヾ(≧▽≦*)o
Loading...