PyTorch allows you to define computational graphs dynamically. This is done by operating on Variables, which wrap PyTorch's Tensor objects.
Here is a simple example:
import torch
from torch.autograd import Variable
import numpy as np
def f(x):
    return x**2 + 2 * x
x = Variable(torch.from_numpy(np.array([4.0])), requires_grad=True)
y = f(x)
y.backward()
x.grad.data # 10 = 2x + 2 at x = 4
x = Variable(torch.from_numpy(np.array([5.0])), requires_grad=True)
y = f(x)
y.backward()
x.grad.data # 12 = 2x + 2 at x = 5
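Rebuilding x from scratch for the second evaluation is not strictly necessary. A sketch of an alternative (not from the original): reuse the same Variable and zero its gradient between backward passes, since .backward() accumulates into .grad rather than overwriting it.

```python
import torch
from torch.autograd import Variable
import numpy as np

def f(x):
    return x**2 + 2 * x

x = Variable(torch.from_numpy(np.array([4.0])), requires_grad=True)
f(x).backward()
print(x.grad.data)   # 10 = 2*4 + 2

x.grad.data.zero_()  # reset; otherwise the next backward would add to 10
x.data = torch.from_numpy(np.array([5.0]))
f(x).backward()
print(x.grad.data)   # 12 = 2*5 + 2
```

Forgetting this reset is a classic source of wrong gradients: without the zero_() call, the second pass would report 22 (10 + 12) instead of 12.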
Note that unlike in TensorFlow, where the graph is defined once up front and then executed, we defined the graph on the fly: each call to f builds a fresh graph. That is why it is convenient to define an ordinary Python function; calling it is part of constructing the graph.
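One consequence of building the graph on the fly is that ordinary Python control flow just works. A minimal sketch (the function g below is hypothetical, not from the original): the branch taken depends on the runtime value of x, so a different graph is built on each call.

```python
import torch
from torch.autograd import Variable
import numpy as np

def g(x):
    # Data-dependent branching: the graph recorded by autograd
    # differs depending on the value of x at call time.
    if (x.data > 0).all():
        return x**2
    else:
        return -x

x = Variable(torch.from_numpy(np.array([3.0])), requires_grad=True)
y = g(x)          # takes the x**2 branch, since x > 0
y.backward()
print(x.grad.data)  # 6 = 2x at x = 3
```

In a static-graph framework this kind of branching requires special graph-level constructs; here it is just an if statement.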