In [1]:
# By Justin Johnson https://github.com/jcjohnson/pytorch-examples/blob/master/nn/dynamic_net.py

import random
import torch
from torch.autograd import Variable

"""
To showcase the power of PyTorch dynamic graphs, we will implement a very strange
model: a fully-connected ReLU network that on each forward pass randomly chooses
a number between 1 and 4 and has that many hidden layers, reusing the same
weights multiple times to compute the innermost hidden layers.
"""

class DynamicNet(torch.nn.Module):
  def __init__(self, D_in, H, D_out):
    """
    In the constructor we construct three nn.Linear instances that we will use
    in the forward pass.
    """
    super(DynamicNet, self).__init__()
    self.input_linear = torch.nn.Linear(D_in, H)
    self.middle_linear = torch.nn.Linear(H, H)
    self.output_linear = torch.nn.Linear(H, D_out)

  def forward(self, x, verbose = False):
    """
    For the forward pass of the model, we randomly choose either 0, 1, 2, or 3
    and reuse the middle_linear Module that many times to compute hidden layer
    representations.
    Since each forward pass builds a dynamic computation graph, we can use normal
    Python control-flow operators like loops or conditional statements when
    defining the forward pass of the model.
    Here we also see that it is perfectly safe to reuse the same Module many
    times when defining a computational graph. This is a big improvement from Lua
    Torch, where each Module could be used only once.
    """
    h_relu = self.input_linear(x).clamp(min=0)
    n_layers = random.randint(0, 3)
    if verbose:
        print("The number of layers for this run is", n_layers)
        # print(h_relu)
    for _ in range(n_layers):
        h_relu = self.middle_linear(h_relu).clamp(min=0)
        # if verbose:
        #     print(h_relu)
    y_pred = self.output_linear(h_relu)
    return y_pred




# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 10, 1

# Create random Tensors to hold inputs and outputs, and wrap them in Variables
x = Variable(torch.randn(N, D_in))
y = Variable(torch.randn(N, D_out), requires_grad=False)

# Construct our model by instantiating the class defined above
model = DynamicNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. Training this strange model with
# vanilla stochastic gradient descent is tough, so we use momentum
criterion = torch.nn.MSELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
for t in range(500):
  # Forward pass: Compute predicted y by passing x to the model
  y_pred = model(x)

  # Compute and print loss
  loss = criterion(y_pred, y)
  print(t, loss.data[0])

  # Zero gradients, perform a backward pass, and update the weights.
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
0 66.81979370117188
1 62.70124053955078
2 60.11635971069336
3 59.01012420654297
4 58.9544677734375
5 59.42093276977539
6 58.79389190673828
7 58.585784912109375
8 58.195953369140625
9 58.74108123779297
10 37.2628173828125
11 58.01183319091797
12 56.33346176147461
13 58.70774841308594
14 57.56559371948242
15 31.000629425048828
16 58.64924621582031
17 57.091064453125
18 53.6650505065918
19 56.62508773803711
20 52.271949768066406
21 51.258689880371094
22 58.11512756347656
23 48.78624725341797
24 47.30036163330078
25 57.446319580078125
26 27.24831771850586
27 56.91810607910156
28 26.6308650970459
29 56.387474060058594
30 39.59881591796875
31 38.3515739440918
32 24.11873435974121
33 49.98106002807617
34 49.13752746582031
35 48.09789276123047
36 53.18637466430664
37 21.792524337768555
38 44.621768951416016
39 20.74881362915039
40 29.225149154663086
41 28.5273380279541
42 40.31317901611328
43 26.860628128051758
44 18.201810836791992
45 44.756046295166016
46 35.72251510620117
47 17.067745208740234
48 24.224334716796875
49 38.0965690612793
50 15.208223342895508
51 14.215011596679688
52 12.78608226776123
53 11.216774940490723
54 9.651387214660645
55 22.624841690063477
56 29.58342742919922
57 28.087305068969727
58 27.007118225097656
59 25.461122512817383
60 5.744221210479736
61 22.746383666992188
62 5.4187469482421875
63 5.176504135131836
64 4.796121120452881
65 20.0670223236084
66 16.9276065826416
67 3.784090518951416
68 3.5151898860931396
69 19.448253631591797
70 14.12013053894043
71 3.0414748191833496
72 13.204912185668945
73 15.031462669372559
74 14.320174217224121
75 13.919631004333496
76 12.410355567932129
77 3.318063497543335
78 12.123408317565918
79 11.39098834991455
80 10.128912925720215
81 9.8468599319458
82 2.7535324096679688
83 9.32371997833252
84 7.845295429229736
85 2.6896610260009766
86 7.9634904861450195
87 7.15878438949585
88 6.074477672576904
89 5.638831615447998
90 2.5943500995635986
91 5.994229316711426
92 4.396603584289551
93 5.417651176452637
94 2.371819019317627
95 3.3579347133636475
96 3.0554635524749756
97 4.243412494659424
98 2.146850347518921
99 2.3707642555236816
100 2.2094995975494385
101 2.044818162918091
102 1.9213523864746094
103 1.7610112428665161
104 1.642599105834961
105 1.5309937000274658
106 4.40830659866333
107 3.9048001766204834
108 1.4214330911636353
109 1.5119552612304688
110 2.820049285888672
111 2.5223183631896973
112 2.3679463863372803
113 1.4368292093276978
114 1.2933924198150635
115 1.4679638147354126
116 2.0784947872161865
117 1.1282614469528198
118 1.0589942932128906
119 1.0066310167312622
120 0.929517388343811
121 1.9272671937942505
122 0.9065648317337036
123 1.8333914279937744
124 0.551647961139679
125 0.8437083959579468
126 1.0596048831939697
127 0.9984511137008667
128 1.2395198345184326
129 0.9647164940834045
130 1.1263706684112549
131 0.945015013217926
132 0.9204602837562561
133 0.7124588489532471
134 0.5728455185890198
135 0.742725133895874
136 0.7666377425193787
137 0.9832253456115723
138 0.4449721574783325
139 0.41877296566963196
140 0.880752444267273
141 0.7153960466384888
142 0.38648515939712524
143 0.5919420719146729
144 0.5247396230697632
145 0.5318360924720764
146 0.5055692791938782
147 0.44386735558509827
148 0.3723817765712738
149 0.3217329978942871
150 1.0343881845474243
151 0.8026278614997864
152 0.3920636475086212
153 0.9063559770584106
154 0.8285737633705139
155 0.7181158065795898
156 0.5305062532424927
157 0.7137705087661743
158 0.6732534170150757
159 0.453058660030365
160 0.5484665036201477
161 0.5566059947013855
162 0.5520378351211548
163 0.5390767455101013
164 0.5901235342025757
165 0.48616209626197815
166 0.44589245319366455
167 0.40851542353630066
168 0.5371256470680237
169 0.34321942925453186
170 0.5438053011894226
171 0.6643029451370239
172 0.4899027645587921
173 0.5306777358055115
174 0.3450915515422821
175 0.4662405848503113
176 0.37746721506118774
177 0.3582407534122467
178 0.45264747738838196
179 0.4232679605484009
180 0.41214287281036377
181 0.486092746257782
182 0.3428241014480591
183 0.4091561734676361
184 0.2879551351070404
185 0.4411312937736511
186 0.4592851996421814
187 0.2482815384864807
188 0.23503336310386658
189 0.22032514214515686
190 0.2053459882736206
191 0.5452612042427063
192 0.5448863506317139
193 0.45595213770866394
194 0.19150827825069427
195 0.4050690829753876
196 0.19678698480129242
197 0.37684592604637146
198 0.41085708141326904
199 0.39771324396133423
200 0.21419884264469147
201 0.21105943620204926
202 0.3220767378807068
203 0.20372898876667023
204 0.2034764140844345
205 0.5483980774879456
206 0.5204151272773743
207 0.5123334527015686
208 0.28036436438560486
209 0.4361844062805176
210 0.29462316632270813
211 0.36272314190864563
212 0.32564136385917664
213 0.3111518323421478
214 0.25811678171157837
215 0.29951661825180054
216 0.27996453642845154
217 0.25109007954597473
218 0.6623751521110535
219 0.22834128141403198
220 0.4407519996166229
221 0.4279693067073822
222 0.2005535066127777
223 0.24736712872982025
224 0.1990315467119217
225 0.35882529616355896
226 0.6465440392494202
227 0.24242916703224182
228 0.5288906693458557
229 0.23824182152748108
230 0.31179723143577576
231 0.33051058650016785
232 0.37014445662498474
233 0.2408895492553711
234 0.2620850205421448
235 0.4307149648666382
236 0.4327353537082672
237 0.27127939462661743
238 0.2367619425058365
239 0.233219176530838
240 0.39933857321739197
241 0.285702109336853
242 0.3715745806694031
243 0.2759586274623871
244 0.2981872856616974
245 0.3323499262332916
246 0.2850937247276306
247 0.2832967936992645
248 0.26399165391921997
249 0.23553146421909332
250 0.3557283580303192
251 0.1933974176645279
252 0.23546989262104034
253 0.23416678607463837
254 0.16889789700508118
255 0.22396452724933624
256 0.16430051624774933
257 0.1619638055562973
258 0.21517841517925262
259 0.4318315386772156
260 0.21091057360172272
261 0.16104541718959808
262 0.39927494525909424
263 0.43209004402160645
264 0.3521919250488281
265 0.3229374885559082
266 0.290527880191803
267 0.2959340512752533
268 0.2362525910139084
269 0.3160015046596527
270 0.30099403858184814
271 0.19246762990951538
272 0.2621278166770935
273 0.25224021077156067
274 0.24537186324596405
275 0.2087096869945526
276 0.23929265141487122
277 0.21777750551700592
278 0.29943525791168213
279 0.5697375535964966
280 0.2084498107433319
281 0.20145222544670105
282 0.30349141359329224
283 0.29999101161956787
284 0.3743019700050354
285 0.2286265790462494
286 0.2349010854959488
287 0.33999374508857727
288 0.32145267724990845
289 0.29463598132133484
290 0.28747421503067017
291 0.25220924615859985
292 0.3102751672267914
293 0.2508939504623413
294 0.3097856342792511
295 0.22689297795295715
296 0.2775665521621704
297 0.18858282268047333
298 0.33132699131965637
299 0.22684210538864136
300 0.31310856342315674
301 0.16501733660697937
302 0.30282944440841675
303 0.3225947916507721
304 0.26472964882850647
305 0.18867895007133484
306 0.2510991394519806
307 0.2484002560377121
308 0.23274771869182587
309 0.22487007081508636
310 0.1694023758172989
311 0.20971743762493134
312 0.3816332519054413
313 0.3825567662715912
314 0.25451093912124634
315 0.2457493096590042
316 0.20889650285243988
317 0.21070371568202972
318 0.1822338104248047
319 0.30178597569465637
320 0.2093411087989807
321 0.20800286531448364
322 0.20138019323349
323 0.19063785672187805
324 0.20248575508594513
325 0.31861039996147156
326 0.1654292494058609
327 0.18406464159488678
328 0.33186110854148865
329 0.17261110246181488
330 0.1679345816373825
331 0.3133294880390167
332 0.1525135338306427
333 0.282392293214798
334 0.22486941516399384
335 0.31785258650779724
336 0.1365605592727661
337 0.30196186900138855
338 0.2868657112121582
339 0.2648289203643799
340 0.23894916474819183
341 0.27521654963493347
342 0.19399793446063995
343 0.23445788025856018
344 0.24096214771270752
345 0.3002694547176361
346 0.29534462094306946
347 0.24750134348869324
348 0.2779269814491272
349 0.21165910363197327
350 0.20300737023353577
351 0.18946193158626556
352 0.22736823558807373
353 0.2348562777042389
354 0.22362038493156433
355 0.20733468234539032
356 0.2475854903459549
357 0.24531008303165436
358 0.16650287806987762
359 0.15709242224693298
360 0.18995612859725952
361 0.14228571951389313
362 0.4518071711063385
363 0.134094700217247
364 0.22038592398166656
365 0.21314726769924164
366 0.21017961204051971
367 0.2019832730293274
368 0.15638434886932373
369 0.23072101175785065
370 0.33102738857269287
371 0.17763686180114746
372 0.16285061836242676
373 0.21188728511333466
374 0.20256376266479492
375 0.19034457206726074
376 0.15954066812992096
377 0.16930334270000458
378 0.1606924831867218
379 0.15127529203891754
380 0.24202530086040497
381 0.24615107476711273
382 0.1812075823545456
383 0.13348977267742157
384 0.33498215675354004
385 0.18923747539520264
386 0.1905059516429901
387 0.14313951134681702
388 0.14261679351329803
389 0.24039535224437714
390 0.2601178288459778
391 0.23952099680900574
392 0.2332109808921814
393 0.15874651074409485
394 0.16348396241664886
395 0.15977570414543152
396 0.20166151225566864
397 0.24505296349525452
398 0.18863420188426971
399 0.2372453361749649
400 0.26944699883461
401 0.14017347991466522
402 0.25345370173454285
403 0.1455172896385193
404 0.21483659744262695
405 0.19386689364910126
406 0.16191509366035461
407 0.18499265611171722
408 0.1485700160264969
409 0.21767349541187286
410 0.15393111109733582
411 0.14663881063461304
412 0.1396859586238861
413 0.3340926468372345
414 0.2295195311307907
415 0.21894294023513794
416 0.16530445218086243
417 0.19163279235363007
418 0.17993079125881195
419 0.18443986773490906
420 0.18124625086784363
421 0.27627816796302795
422 0.17593249678611755
423 0.2576260566711426
424 0.16154754161834717
425 0.1570221185684204
426 0.22066833078861237
427 0.20458157360553741
428 0.14121127128601074
429 0.20393750071525574
430 0.19643907248973846
431 0.2502305209636688
432 0.25012803077697754
433 0.16579250991344452
434 0.23394569754600525
435 0.14524543285369873
436 0.2099425047636032
437 0.1960260272026062
438 0.18063636124134064
439 0.1725504845380783
440 0.14977110922336578
441 0.15014778077602386
442 0.14302337169647217
443 0.15651850402355194
444 0.3694646954536438
445 0.15610407292842865
446 0.34782442450523376
447 0.15768644213676453
448 0.223859965801239
449 0.15451045334339142
450 0.29071494936943054
451 0.20661909878253937
452 0.2609935402870178
453 0.18736699223518372
454 0.17477107048034668
455 0.2173418551683426
456 0.20596471428871155
457 0.1938660591840744
458 0.17813685536384583
459 0.16240327060222626
460 0.2822168469429016
461 0.28906214237213135
462 0.28447872400283813
463 0.2706608176231384
464 0.25063154101371765
465 0.26701921224594116
466 0.18861491978168488
467 0.22946858406066895
468 0.20551608502864838
469 0.1802937388420105
470 0.17821013927459717
471 0.20674364268779755
472 0.1826205849647522
473 0.21166613698005676
474 0.17611679434776306
475 0.14747339487075806
476 0.160769522190094
477 0.14298877120018005
478 0.14171350002288818
479 0.13809143006801605
480 0.282638818025589
481 0.1307322382926941
482 0.12838256359100342
483 0.16531269252300262
484 0.16637392342090607
485 0.16354621946811676
486 0.1577855497598648
487 0.15023595094680786
488 0.12309899926185608
489 0.125056192278862
490 0.2715856730937958
491 0.1269981414079666
492 0.1211034283041954
493 0.20880068838596344
494 0.15389057993888855
495 0.2718507945537567
496 0.17106717824935913
497 0.13586175441741943
498 0.13287678360939026
499 0.12461919337511063
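Note that this notebook was written against the pre-0.4 PyTorch API: tensors are wrapped in Variable, the loss is built with size_average=False, and the scalar loss is read with loss.data[0]. On recent PyTorch those are deprecated or removed. Below is a minimal sketch of the same training loop against the newer API, assuming a recent PyTorch (1.0 or later) and the DynamicNet class defined above; it was not run in this notebook.

import torch

N, D_in, H, D_out = 64, 1000, 10, 1

# Plain tensors now carry autograd state; no Variable wrapper is needed.
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

model = DynamicNet(D_in, H, D_out)

# reduction='sum' replaces size_average=False
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

for t in range(500):
    y_pred = model(x)
    loss = criterion(y_pred, y)
    print(t, loss.item())  # .item() replaces .data[0]

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()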
In [2]:
model(x)[1:5]
Out[2]:
Variable containing:
-1.2218
-0.4829
-0.2787
 0.1339
[torch.FloatTensor of size 4x1]
In [3]:
model(x)[1:5] # another run
Out[3]:
Variable containing:
-1.2218
-0.4829
-0.2787
 0.1339
[torch.FloatTensor of size 4x1]
In [4]:
model(x)[1:5]
Out[4]:
Variable containing:
-1.3516
-0.5395
-0.2080
 0.1541
[torch.FloatTensor of size 4x1]

The first two calls return identical values, while the third differs slightly. Let's turn on verbose mode to see what's happening inside

In [5]:
model(x, verbose = True)[1:5]
The number of layers for this run is 3
Out[5]:
Variable containing:
-1.3516
-0.5395
-0.2080
 0.1541
[torch.FloatTensor of size 4x1]
In [6]:
model(x, verbose = True)[1:5]
The number of layers for this run is 1
Out[6]:
Variable containing:
-1.3114
-0.5785
-0.2852
 0.0950
[torch.FloatTensor of size 4x1]
In [7]:
model(x, verbose = True)[1:5]
The number of layers for this run is 0
Out[7]:
Variable containing:
-1.2218
-0.4829
-0.2787
 0.1339
[torch.FloatTensor of size 4x1]
In [8]:
model(x, verbose = True)[1:5]
The number of layers for this run is 0
Out[8]:
Variable containing:
-1.2218
-0.4829
-0.2787
 0.1339
[torch.FloatTensor of size 4x1]
In [9]:
model(x, verbose = True)[1:5]
The number of layers for this run is 2
Out[9]:
Variable containing:
-1.3534
-0.5795
-0.2706
 0.1067
[torch.FloatTensor of size 4x1]
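In the runs above, the two calls that drew zero middle layers (In [7] and In [8]) return exactly the same values as In [2] and In [3], while calls that drew a different depth differ slightly: for a fixed x, the output is fully determined by the sampled depth. If you want these demonstrations to be repeatable, you can seed Python's random module before calling the model; a minimal sketch (the seed value is arbitrary):

import random
random.seed(0)  # fixes the sequence of sampled depths, so repeated forward passes match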

So what's the target?

In [10]:
y[1:5]
Out[10]:
Variable containing:
-1.3277
-0.6187
-0.2690
 0.1163
[torch.FloatTensor of size 4x1]
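Finally, a quick sanity check on the weight sharing the docstring describes: no matter how many times middle_linear is applied in a forward pass, the model owns exactly three Linear layers, so its parameter count is fixed. A minimal sketch, assuming the model and dimensions defined above:

# Count parameters and compare with the three Linear layers' weights plus biases.
n_params = sum(p.numel() for p in model.parameters())
expected = (D_in * H + H) + (H * H + H) + (H * D_out + D_out)
print(n_params, expected)  # both 10131 for D_in=1000, H=10, D_out=1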