校园春色亚洲色图_亚洲视频分类_中文字幕精品一区二区精品_麻豆一区区三区四区产品精品蜜桃

主頁(yè) > 知識(shí)庫(kù) > pytorch自定義不可導(dǎo)激活函數(shù)的操作

pytorch自定義不可導(dǎo)激活函數(shù)的操作

熱門(mén)標(biāo)簽:商家地圖標(biāo)注海報(bào) 打電話機(jī)器人營(yíng)銷 騰訊地圖標(biāo)注沒(méi)法顯示 南陽(yáng)打電話機(jī)器人 孝感營(yíng)銷電話機(jī)器人效果怎么樣 地圖標(biāo)注自己和別人標(biāo)注區(qū)別 海外網(wǎng)吧地圖標(biāo)注注冊(cè) 聊城語(yǔ)音外呼系統(tǒng) ai電銷機(jī)器人的優(yōu)勢(shì)

pytorch自定義不可導(dǎo)激活函數(shù)

今天自定義不可導(dǎo)函數(shù)的時(shí)候遇到了一個(gè)大坑。

首先我需要自定義一個(gè)函數(shù):sign_f

import torch
from torch.autograd import Function
import torch.nn as nn
class sign_f(Function):
    @staticmethod
    def forward(ctx, inputs):
        output = inputs.new(inputs.size())
        output[inputs >= 0.] = 1
        output[inputs  0.] = -1
        ctx.save_for_backward(inputs)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, = ctx.saved_tensors
        grad_output[input_>1.] = 0
        grad_output[input_-1.] = 0
        return grad_output

然后我需要把它封裝為一個(gè)module 類型,就像 nn.Conv2d 模塊 封裝 f.conv2d 一樣,于是

import torch
from torch.autograd import Function
import torch.nn as nn
class sign_(nn.Module):
	# 我需要的module
    def __init__(self, *kargs, **kwargs):
        super(sign_, self).__init__(*kargs, **kwargs)
        
    def forward(self, inputs):
    	# 使用自定義函數(shù)
        outs = sign_f(inputs)
        return outs

class sign_f(Function):
    @staticmethod
    def forward(ctx, inputs):
        output = inputs.new(inputs.size())
        output[inputs >= 0.] = 1
        output[inputs  0.] = -1
        ctx.save_for_backward(inputs)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, = ctx.saved_tensors
        grad_output[input_>1.] = 0
        grad_output[input_-1.] = 0
        return grad_output

結(jié)果報(bào)錯(cuò)

TypeError: backward() missing 2 required positional arguments: 'ctx' and 'grad_output'

我試了半天,發(fā)現(xiàn)自定義函數(shù)后面要加 apply ,詳細(xì)見(jiàn)下面

import torch
from torch.autograd import Function
import torch.nn as nn
class sign_(nn.Module):

    def __init__(self, *kargs, **kwargs):
        super(sign_, self).__init__(*kargs, **kwargs)
        self.r = sign_f.apply ### -----注意此處
        
    def forward(self, inputs):
        outs = self.r(inputs)
        return outs

class sign_f(Function):
    @staticmethod
    def forward(ctx, inputs):
        output = inputs.new(inputs.size())
        output[inputs >= 0.] = 1
        output[inputs  0.] = -1
        ctx.save_for_backward(inputs)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, = ctx.saved_tensors
        grad_output[input_>1.] = 0
        grad_output[input_-1.] = 0
        return grad_output

問(wèn)題解決了!

PyTorch自定義帶學(xué)習(xí)參數(shù)的激活函數(shù)(如sigmoid)

有的時(shí)候我們需要給損失函數(shù)設(shè)一個(gè)超參數(shù)但是又不想設(shè)固定閾值想和網(wǎng)絡(luò)一起自動(dòng)學(xué)習(xí),例如給Sigmoid一個(gè)參數(shù)alpha進(jìn)行調(diào)節(jié)

函數(shù)如下:

import torch.nn as nn
import torch
class LearnableSigmoid(nn.Module):
    def __init__(self, ):
        super(LearnableSigmoid, self).__init__()
        self.weight = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True)

        self.reset_parameters()
    def reset_parameters(self):
        self.weight.data.fill_(1.0)
        
    def forward(self, input):
        return 1/(1 +  torch.exp(-self.weight*input))

驗(yàn)證和Sigmoid的一致性

class LearnableSigmoid(nn.Module):
    def __init__(self, ):
        super(LearnableSigmoid, self).__init__()
        self.weight = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True)

        self.reset_parameters()
    def reset_parameters(self):
        self.weight.data.fill_(1.0)
        
    def forward(self, input):
        return 1/(1 +  torch.exp(-self.weight*input))
   
Sigmoid = nn.Sigmoid()
LearnSigmoid = LearnableSigmoid()
input = torch.tensor([[0.5289, 0.1338, 0.3513],
        [0.4379, 0.1828, 0.4629],
        [0.4302, 0.1358, 0.4180]])

print(Sigmoid(input))
print(LearnSigmoid(input))

輸出結(jié)果

tensor([[0.6292, 0.5334, 0.5869],
[0.6078, 0.5456, 0.6137],
[0.6059, 0.5339, 0.6030]])

tensor([[0.6292, 0.5334, 0.5869],
[0.6078, 0.5456, 0.6137],
[0.6059, 0.5339, 0.6030]], grad_fn=MulBackward0>)

驗(yàn)證權(quán)重是不是會(huì)更新

import torch.nn as nn
import torch
import torch.optim as optim
class LearnableSigmoid(nn.Module):
    def __init__(self, ):
        super(LearnableSigmoid, self).__init__()
        self.weight = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True)

        self.reset_parameters()

    def reset_parameters(self):
        self.weight.data.fill_(1.0)
        
    def forward(self, input):
        return 1/(1 +  torch.exp(-self.weight*input))
        
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()       
        self.LSigmoid = LearnableSigmoid()
    def forward(self, x):                
        x = self.LSigmoid(x)
        return x

net = Net()  
print(list(net.parameters()))
optimizer = optim.SGD(net.parameters(), lr=0.01)
learning_rate=0.001
input_data=torch.randn(10,2)
target=torch.FloatTensor(10, 2).random_(8)
criterion = torch.nn.MSELoss(reduce=True, size_average=True)

for i in range(2):
    optimizer.zero_grad()     
    output = net(input_data)   
    loss = criterion(output, target)
    loss.backward()             
    optimizer.step()           
    print(list(net.parameters()))

輸出結(jié)果

tensor([1.], requires_grad=True)]
[Parameter containing:
tensor([0.9979], requires_grad=True)]
[Parameter containing:
tensor([0.9958], requires_grad=True)]

會(huì)更新~

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

您可能感興趣的文章:
  • pytorch方法測(cè)試——激活函數(shù)(ReLU)詳解
  • PyTorch中常用的激活函數(shù)的方法示例
  • Pytorch 實(shí)現(xiàn)自定義參數(shù)層的例子

標(biāo)簽:南寧 迪慶 牡丹江 撫州 聊城 楊凌 揚(yáng)州 六盤(pán)水

巨人網(wǎng)絡(luò)通訊聲明:本文標(biāo)題《pytorch自定義不可導(dǎo)激活函數(shù)的操作》,本文關(guān)鍵詞  pytorch,自定義,不,可導(dǎo),激活,;如發(fā)現(xiàn)本文內(nèi)容存在版權(quán)問(wèn)題,煩請(qǐng)?zhí)峁┫嚓P(guān)信息告之我們,我們將及時(shí)溝通與處理。本站內(nèi)容系統(tǒng)采集于網(wǎng)絡(luò),涉及言論、版權(quán)與本站無(wú)關(guān)。
  • 相關(guān)文章
  • 下面列出與本文章《pytorch自定義不可導(dǎo)激活函數(shù)的操作》相關(guān)的同類信息!
  • 本頁(yè)收集關(guān)于pytorch自定義不可導(dǎo)激活函數(shù)的操作的相關(guān)信息資訊供網(wǎng)民參考!
  • 推薦文章
    主站蜘蛛池模板: 象山县| 商水县| 普格县| 嘉荫县| 江川县| 桓台县| 蕲春县| 岑巩县| 德化县| 明溪县| 金秀| 阿拉善盟| 桂阳县| 界首市| 嵊州市| 洪洞县| 西畴县| 镇原县| 台东市| 安溪县| 扎鲁特旗| 道真| 宜兴市| 呼和浩特市| 沾化县| 广丰县| 万盛区| 龙南县| 定南县| 新和县| 深水埗区| 镇安县| 红原县| 讷河市| 洛南县| 马公市| 体育| 离岛区| 永兴县| 建宁县| 抚松县|