diff --git a/beginner_source/basics/autogradqs_tutorial.py b/beginner_source/basics/autogradqs_tutorial.py
index 8eff127ddee..671ed67c817 100644
--- a/beginner_source/basics/autogradqs_tutorial.py
+++ b/beginner_source/basics/autogradqs_tutorial.py
@@ -32,7 +32,7 @@
 y = torch.zeros(3)  # expected output
 w = torch.randn(5, 3, requires_grad=True)
 b = torch.randn(3, requires_grad=True)
-z = torch.matmul(x, w)+b
+z = torch.matmul(x, w) + b
 loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
 
 
@@ -133,7 +133,8 @@
 # - To mark some parameters in your neural network as **frozen parameters**.
 # - To **speed up computations** when you are only doing forward pass, because computations on tensors that do
 #   not track gradients would be more efficient.
-
+# For additional reference, see the autograd mechanics documentation:
+# https://docs.pytorch.org/docs/stable/notes/autograd.html#locally-disabling-gradient-computation
 
 
 ######################################################################
@@ -160,6 +161,39 @@
 # - accumulates them in the respective tensor’s ``.grad`` attribute
 # - using the chain rule, propagates all the way to the leaf tensors.
 #
+# We can also visualize the computational graph using the following two methods:
+#
+# 1. ``TORCH_LOGS="+autograd"``
+#    Setting the ``TORCH_LOGS="+autograd"`` environment variable enables runtime autograd logs, which are useful for debugging.
+#    The logging can be turned on when running a script, for example:
+#
+#    ``TORCH_LOGS="+autograd" python test.py``
+#
+# 2. Torchviz
+#    Torchviz is a package that renders the computational graph visually.
+#    The example below generates an image of the computational graph:
+#
+#    import torch
+#    from torch import nn
+#    from torchviz import make_dot
+#
+#    model = nn.Sequential(
+#        nn.Linear(8, 16),
+#        nn.ReLU(),
+#        nn.Linear(16, 1)
+#    )
+#
+#    x = torch.randn(1, 8, requires_grad=True)
+#    y = model(x).mean()
+#
+#    # optionally set the TORCH_LOGS variable from within the script (see method 1)
+#    import os
+#    os.environ['TORCH_LOGS'] = "+autograd"
+#
+#    # render the graph of operations that produced ``y``
+#    dot = make_dot(y, params=dict(model.named_parameters()), show_attrs=True, show_saved=True)
+#    dot.render('simple_graph', format='png')
+#
 # .. note::
 #   **DAGs are dynamic in PyTorch**
 #   An important thing to note is that the graph is recreated from scratch; after each