拟合 $y = 2x + 365$，展示 loss 逐渐变小的过程。
本质是一个 MLP：$y = \mathrm{FC}(\mathrm{ReLU}(\mathrm{FC}(x)))$，损失函数为 Huber loss，使用 Adam 优化器。
import numpy as np
import torch
import torch.nn as nn
class MyModel(nn.Module):
    """Two-layer MLP: ``y = FC2(ReLU(FC(x)))``.

    The defaults reproduce the original fixed architecture
    (192 -> 768 -> 192), in which a whole 192-element vector is one sample.
    The layer sizes are constructor arguments so the same model can be
    reused for other shapes, e.g. ``MyModel(1, 64, 1)`` for a point-wise
    scalar regression.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Width of the hidden layer.
        out_features: Size of the last dimension of the output.
    """

    def __init__(self, in_features: int = 192, hidden_features: int = 768,
                 out_features: int = 192):
        super().__init__()
        # Attribute names ``fc`` / ``fc2`` are kept unchanged so existing
        # state_dict checkpoints ("fc.weight", "fc2.bias", ...) still load.
        self.fc = nn.Linear(in_features=in_features, out_features=hidden_features)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(in_features=hidden_features, out_features=out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` of shape ``(..., in_features)`` to ``(..., out_features)``."""
        x = self.fc(x)
        # The pre-activation produced by ``self.fc`` is not reused afterwards,
        # so applying ReLU in place is safe.
        x = self.relu(x)
        return self.fc2(x)
def huber_loss(preds, labels, delta=1.0):
    """Mean Huber loss between ``preds`` and ``labels``.

    Element-wise, with ``r = |preds - labels|``::

        0.5 * r**2                    if r <= delta
        delta * r - 0.5 * delta**2    otherwise

    The per-element values are then averaged over all elements.

    Args:
        preds: Predicted values.
        labels: Target values; must broadcast against ``preds``.
        delta: Threshold between the quadratic and the linear regime.

    Returns:
        A 0-dim tensor holding the mean loss.
    """
    abs_err = (preds - labels).abs()
    # Both pieces have equal value and slope at r == delta, so the loss is
    # smooth at the switch-over point.
    quadratic = 0.5 * abs_err.square()
    linear = abs_err * delta - 0.5 * delta * delta
    in_quadratic_zone = abs_err <= delta
    return torch.where(in_quadratic_zone, quadratic, linear).mean()
def func(x: torch.Tensor) -> torch.Tensor:
    """Return the ground-truth target f(x) = 2x + 365, applied element-wise.

    Args:
        x: Input tensor of any shape.

    Returns:
        Tensor of the same shape as ``x``. Its dtype follows PyTorch type
        promotion: a float32 input gives float32, an integer input stays
        integer. The result is therefore not necessarily a (legacy)
        ``torch.FloatTensor``, which is why the annotation is ``torch.Tensor``.
    """
    return x * 2 + 365
def main(epochs: int = 2000, lr: float = 1e-4, seed: int = 0) -> None:
    """Train ``MyModel`` on samples of f(x) = 2x + 365 and print the losses.

    Each training sample is a whole 192-element vector of x values (an evenly
    spaced run with step 2, ascending or descending); its target is ``func``
    applied element-wise. Prints the mean Huber loss of every epoch, then the
    Huber loss on a held-out vector of the odd numbers 1..383 (ascending).

    Args:
        epochs: Number of passes over the training vectors.
        lr: Adam learning rate.
        seed: Passed to ``torch.manual_seed`` so runs are reproducible.
    """
    torch.manual_seed(seed)

    # (start, stop, step) for torch.arange. Every range yields exactly 192
    # values, matching the 192 input features of MyModel().
    train_ranges = [
        (0, 384, 2),
        (382, -1, -2),
        (384, 768, 2),
        (767, 383, -2),
        (768, 384, -2),
        (383, 0, -2),
        (385, 2, -2),
    ]
    train_x = [torch.arange(*r, dtype=torch.float32) for r in train_ranges]
    train_y = [func(x) for x in train_x]
    test_x = torch.arange(1, 384, 2, dtype=torch.float32)
    test_y = func(test_x)

    mynet = MyModel()
    mynet.train()
    optimizer = torch.optim.Adam(params=mynet.parameters(), lr=lr)  # Adam optimizer
    for epoch in range(epochs):
        losses = []
        for x, y in zip(train_x, train_y):
            optimizer.zero_grad()  # reset gradients from the previous step
            pred = mynet(x)  # forward pass
            loss = huber_loss(pred, y)
            loss.backward()  # back-propagate to compute gradients
            optimizer.step()  # update all parameters
            losses.append(loss.item())
        print("Epoch {}, Loss: {}".format(epoch, np.mean(losses)))

    mynet.eval()
    with torch.no_grad():
        test_loss = huber_loss(mynet(test_x), test_y)
    print("Test loss: {}".format(test_loss.item()))


if __name__ == '__main__':
    main()
运行效果如下：
C:\Users\60946\.conda\envs\libcity\python.exe D:\Code\2023\ST-RetNet\test\test4.py
Epoch 0, Loss: 1057.2236328125
Epoch 1, Loss: 1002.7907104492188
Epoch 2, Loss: 939.7447509765625
Epoch 3, Loss: 859.8356323242188
Epoch 4, Loss: 756.3015747070312
Epoch 5, Loss: 624.2710571289062
...
Epoch 20, Loss: 175.79489135742188
...
Epoch 50, Loss: 101.21642303466797
...
Epoch 196, Loss: 18.970491409301758
Epoch 197, Loss: 14.733353614807129
Epoch 198, Loss: 19.96961784362793
Epoch 199, Loss: 14.879599571228027
Epoch 200, Loss: 18.989852905273438
...
Epoch 1995, Loss: 9.355721473693848
Epoch 1996, Loss: 14.7473783493042
Epoch 1997, Loss: 9.928330421447754
Epoch 1998, Loss: 14.550915718078613
Epoch 1999, Loss: 9.781781196594238
Test loss: 7.365930080413818