Signature
CLASS torch.nn.Softshrink(lambd=0.5)
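Parameters
lambd (float): the λ value for the Softshrink formulation (must be no less than zero). Default: 0.5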
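A minimal sketch of how `lambd` changes the result, assuming a recent PyTorch release. It also calls the functional form `torch.nn.functional.softshrink`, which takes the same `lambd` argument. The input values are arbitrary and chosen only for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Arbitrary sample points, including values exactly at +/-0.5 and +/-1.0
x = torch.tensor([-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0])

# Default threshold lambd=0.5
default = nn.Softshrink()(x)

# A larger threshold zeroes a wider band around 0
wide = nn.Softshrink(lambd=1.0)(x)

# Functional form with the same threshold as the module above
functional = F.softshrink(x, lambd=1.0)

print(default)
print(wide)
print(torch.equal(wide, functional))  # True: module and function agree
```

Because the comparisons in the definition are strict, inputs lying exactly at ±lambd are mapped to 0.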
Definition
$$\text{SoftShrinkage}(x) = \begin{cases} x - \lambda, & \text{if } x > \lambda \\ x + \lambda, & \text{if } x < -\lambda \\ 0, & \text{otherwise} \end{cases}$$
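To connect the piecewise formula to the module, here is a sketch that writes the definition out by hand with `torch.where`, plus the equivalent closed form sign(x) · max(|x| − λ, 0). The helper names `soft_shrinkage` and `soft_shrinkage_closed` are hypothetical and only for illustration; they are not PyTorch APIs.

```python
import torch
import torch.nn as nn


def soft_shrinkage(x: torch.Tensor, lambd: float = 0.5) -> torch.Tensor:
    """Hypothetical helper: the three cases of the definition via nested torch.where."""
    zero = torch.zeros_like(x)
    return torch.where(x > lambd, x - lambd,
                       torch.where(x < -lambd, x + lambd, zero))


def soft_shrinkage_closed(x: torch.Tensor, lambd: float = 0.5) -> torch.Tensor:
    """Hypothetical helper: shrink |x| by lambd, clip at 0, then restore the sign."""
    return torch.sign(x) * torch.clamp(x.abs() - lambd, min=0)


torch.manual_seed(0)
x = torch.randn(1000)

ref = nn.Softshrink(lambd=0.5)(x)
mine = soft_shrinkage(x, 0.5)
closed = soft_shrinkage_closed(x, 0.5)

print(torch.allclose(ref, mine), torch.allclose(ref, closed))
```

The closed form makes the name explicit: each value is pulled toward 0 by λ, and anything that would cross 0 is set to exactly 0.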
Graph
[Figure: plot of SoftShrinkage(x); image not preserved]
Code
import torch
import torch.nn as nn

# Softshrink with the default threshold lambd=0.5
m = nn.Softshrink()
input = torch.randn(4)
output = m(input)
print("input: ", input)
print("output: ", output)
# Example run (input is random, so your values will differ):
# input: tensor([ 0.9876, -2.0183, -0.7573, -1.7960])
# output: tensor([ 0.4876, -1.5183, -0.2573, -1.2960])
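Since `torch.randn` makes the output above vary between runs, the sketch below feeds the recorded input back in as a fixed tensor to check the numbers, and then uses autograd to show the gradient: 1 where |x| > λ and 0 inside the dead zone [−λ, λ]. The second input tensor is an arbitrary choice for illustration.

```python
import torch
import torch.nn as nn

m = nn.Softshrink()  # lambd defaults to 0.5

# Re-use the input recorded in the example above so the result is reproducible
x = torch.tensor([0.9876, -2.0183, -0.7573, -1.7960])
y = m(x)
expected = torch.tensor([0.4876, -1.5183, -0.2573, -1.2960])
print(torch.allclose(y, expected, atol=1e-4))  # True

# Gradient: passes through (1) outside [-lambd, lambd], blocked (0) inside it
z = torch.tensor([-1.0, -0.3, 0.0, 0.3, 1.0], requires_grad=True)
m(z).sum().backward()
print(z.grad)  # tensor([1., 0., 0., 0., 1.])
```

The zero gradient inside the dead zone is what lets Softshrink drive small values to exactly 0, which is why it is often associated with sparsity-inducing updates.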
Softshrink — PyTorch 1.13 documentation