Programming in AI: LeNet (a PyTorch Walkthrough)
Data Representation in PyTorch
A tensor in PyTorch is essentially a multidimensional array.
```python
import torch
```
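A short sketch of what "multidimensional array" means in practice, using only core `torch` calls. It shows a 2-D tensor's shape and dtype, a 4-D tensor in the (batch, channels, height, width) layout commonly used for image batches, and NumPy-style indexing and reductions. The specific values and the batch size of 8 are illustrative.

```python
import torch

# A 2-D tensor (a matrix) built from a nested Python list
a = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])
print(a.shape)   # torch.Size([2, 3])
print(a.dtype)   # torch.float32
print(a.ndim)    # 2

# A 4-D tensor: a batch of 8 RGB images of size 32 x 32 (batch, channels, height, width)
images = torch.zeros(8, 3, 32, 32)
print(images.shape)  # torch.Size([8, 3, 32, 32])

# Indexing, reshaping and reductions work much like NumPy arrays
first_row = a[0]          # the first row, a 1-D tensor of 3 elements
flat = a.reshape(-1)      # all 6 elements in a 1-D tensor
col_sums = a.sum(dim=0)   # sum down each column: 1+4, 2+5, 3+6
print(col_sums)           # tensor([5., 7., 9.])
```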
Load Data
CIFAR-10 dataset: 50k training images, 10k test images, 10 classes.
```python
import torch
```
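One possible way to load CIFAR-10, sketched under some assumptions. `cifar10_loaders` uses `torchvision.datasets.CIFAR10` (which needs `torchvision` and a download). The normalization constants `(0.5, 0.5, 0.5)` and the batch size of 64 are a common tutorial choice, not taken from the original. Because the real download may not be available, the hypothetical helper `fake_cifar10_loader` builds a stand-in `DataLoader` with the same shapes (3 x 32 x 32 images, labels in 0..9), so the rest of the pipeline can be exercised offline.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

NUM_CLASSES = 10             # CIFAR-10 has 10 classes
IMAGE_SHAPE = (3, 32, 32)    # RGB, 32 x 32 pixels


def cifar10_loaders(root="./data", batch_size=64):
    """Real CIFAR-10 loaders (requires torchvision and network access for the download)."""
    import torchvision
    import torchvision.transforms as T

    transform = T.Compose([
        T.ToTensor(),                                   # PIL image -> float tensor in [0, 1]
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1]
    ])
    train_set = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=transform)
    test_set = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=transform)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)   # 50k images
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)    # 10k images
    return train_loader, test_loader


def fake_cifar10_loader(n=256, batch_size=64):
    """Hypothetical stand-in with CIFAR-10 shapes, for running without the download."""
    images = torch.randn(n, *IMAGE_SHAPE)
    labels = torch.randint(0, NUM_CLASSES, (n,))
    return DataLoader(TensorDataset(images, labels), batch_size=batch_size, shuffle=True)


loader = fake_cifar10_loader()
images, labels = next(iter(loader))
print(images.shape, labels.shape)  # torch.Size([64, 3, 32, 32]) torch.Size([64])
```

Shuffling only the training loader is the usual design: the test order does not affect evaluation.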
Implement LeNet
```python
class LeNet(torch.nn.Module):
    ...
```
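The class body is not shown here, so the following is a minimal sketch of a LeNet-5-style network sized for CIFAR-10 (3 x 32 x 32 inputs, 10 classes). It is an assumption, not necessarily the author's exact layer list. It uses ReLU and max pooling, a common modern substitution for the sigmoid/tanh activations and average subsampling of the original LeNet-5. The comments trace how a 32 x 32 image shrinks to the 16 x 5 x 5 = 400 features fed to the first fully connected layer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LeNet(nn.Module):
    """LeNet-5-style CNN for 3 x 32 x 32 inputs (sketch; layer sizes follow the classic design)."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, kernel_size=5)    # 3x32x32 -> 6x28x28
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)   # 6x14x14 -> 16x10x10
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)     # -> 6x14x14
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)     # -> 16x5x5
        x = torch.flatten(x, 1)                        # -> 400 features per sample
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)                             # raw class scores (logits), no softmax


net = LeNet()
out = net(torch.randn(4, 3, 32, 32))   # a batch of 4 random "images"
print(out.shape)                       # torch.Size([4, 10])
n_params = sum(p.numel() for p in net.parameters())
print(n_params)                        # 62006
```

Returning logits rather than probabilities pairs naturally with `nn.CrossEntropyLoss`, which applies log-softmax internally.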
Auto-Differentiation in PyTorch
```python
net = LeNet()
```
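A sketch of how autograd works, in two steps. First, a plain tensor with `requires_grad=True`: for y = sum(x²), backpropagation should give dy/dx = 2x. Second, the same mechanism on a module. A tiny `nn.Linear(4, 3)` stands in for the LeNet above so the example stays self-contained; the input sizes and the loss choice are illustrative assumptions.

```python
import torch
import torch.nn as nn

# Autograd on a plain tensor: y = sum(x^2), so dy/dx = 2x
x = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)
y = (x ** 2).sum()
y.backward()            # computes dy/dx and stores it in x.grad
print(x.grad)           # 2 * x, i.e. the values 2, -4, 6

# The same mechanism on a network (nn.Linear stands in for LeNet; any nn.Module behaves alike)
torch.manual_seed(0)
model = nn.Linear(4, 3)
inputs = torch.randn(8, 4)
targets = torch.randint(0, 3, (8,))

loss = nn.CrossEntropyLoss()(model(inputs), targets)
model.zero_grad()       # gradients accumulate across backward() calls, so clear them first
loss.backward()         # fills .grad on every parameter of the model
print(model.weight.grad.shape)  # torch.Size([3, 4]), same shape as the weight itself
```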
Optimization in PyTorch
Implementing gradient descent in Python
```python
learning_rate = 0.001
```
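A sketch of one gradient-descent step, θ ← θ − learning_rate · ∇θ L, written by hand and then with `torch.optim.SGD`. It uses the `learning_rate = 0.001` from the text. The model (`nn.Linear(4, 3)` standing in for LeNet), the random data and the cross-entropy loss are assumptions for illustration. The manual update runs inside `torch.no_grad()` so the weight change itself is not recorded in the autograd graph. Plain SGD without momentum or weight decay should produce the same parameters.

```python
import copy
import torch
import torch.nn as nn

learning_rate = 0.001

torch.manual_seed(0)
model = nn.Linear(4, 3)                # stand-in for LeNet; any nn.Module works
inputs = torch.randn(8, 4)
targets = torch.randint(0, 3, (8,))
criterion = nn.CrossEntropyLoss()

twin = copy.deepcopy(model)            # identical copy, updated with torch.optim.SGD below
initial = [p.detach().clone() for p in model.parameters()]

# 1) Hand-written gradient descent: theta <- theta - lr * grad
model.zero_grad()
criterion(model(inputs), targets).backward()
with torch.no_grad():                  # do not track the update in the autograd graph
    for p in model.parameters():
        p -= learning_rate * p.grad

# 2) The same single step with the built-in optimizer
optimizer = torch.optim.SGD(twin.parameters(), lr=learning_rate)
optimizer.zero_grad()
criterion(twin(inputs), targets).backward()
optimizer.step()

for p, q in zip(model.parameters(), twin.parameters()):
    print(torch.allclose(p, q))        # True for both weight and bias
```

In a real training loop these steps (zero the gradients, forward, backward, step) repeat for every mini-batch.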
Parameters: Tensors with gradients
```python
x = torch.rand(5, 5)
```
```python
y.grad
```
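The two snippets above are fragments, so here is a hedged sketch of the idea in the heading: which tensors end up with a `.grad`. A tensor from `torch.rand` does not track gradients until `requires_grad_()` is called. After `backward()`, only leaf tensors get `.grad` by default. An intermediate result such as `y` keeps its gradient only if `retain_grad()` was called; otherwise `y.grad` is `None`. The computation `y = x * 2` is an illustrative assumption. `nn.Parameter` is the tensor type that modules register as trainable parameters.

```python
import torch
import torch.nn as nn

x = torch.rand(5, 5)
print(x.requires_grad)      # False: plain tensors do not track gradients by default

x.requires_grad_()          # x becomes a leaf tensor that records operations
y = x * 2                   # y is a non-leaf tensor (the result of an operation)
y.retain_grad()             # without this, y.grad would be None after backward()
z = y.sum()
z.backward()

print(x.grad)               # dz/dx = 2 in every position, shape (5, 5)
print(y.grad)               # dz/dy = 1 in every position

# nn.Parameter: a tensor with requires_grad=True that a Module registers automatically
w = nn.Parameter(torch.rand(5, 5))
m = nn.Module()
m.w = w                     # assigning a Parameter attribute registers it
print(w.requires_grad, len(list(m.parameters())))  # True 1
```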