zero.training.learn

zero.training.learn(model, optimizer, loss_fn, step, batch, star=False)

The “default” training step.
The function does the following:
1. Switches the model to training mode and sets its gradients to zero.
2. Performs the call step(batch).
3. Passes the output of step to loss_fn (unpacked if star is True, see the parameters below).
4. Applies torch.Tensor.backward to the obtained loss tensor.
5. Performs the optimization step.
6. Returns the loss’s value (float) and step’s output.
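In code, the sequence above amounts to roughly the following sketch (an illustration of the documented behavior, not the library’s actual implementation):

    def learn_sketch(model, optimizer, loss_fn, step, batch, star=False):
        # 1. training mode, zeroed gradients
        model.train()
        optimizer.zero_grad()
        # 2. run the user-provided step
        step_output = step(batch)
        # 3. compute the loss, unpacking the output if star=True
        loss = loss_fn(*step_output) if star else loss_fn(step_output)
        # 4. backpropagation
        loss.backward()
        # 5. optimization step
        optimizer.step()
        # 6. the loss's value (float) and step's output
        return loss.item(), step_output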
Parameters
    model (torch.nn.Module) – the model to train
    optimizer (torch.optim.Optimizer) – the optimizer for model
    loss_fn (Callable[..., torch.Tensor]) – the function that takes step’s output as input and returns a loss tensor
    step (Callable[[T], Any]) – the function that takes batch as input and produces input for loss_fn
    batch (T) – input for step
    star (bool) – if True, the output of step is unpacked when passed to loss_fn, i.e. loss_fn(*step_output) is performed instead of loss_fn(step_output)
Returns
    (loss_value, step_output)
Return type
    Tuple[float, Any]
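The star flag matters when loss_fn expects several positional arguments. A toy illustration (the tensor shapes here are made up for the example):

    import torch
    import torch.nn.functional as F

    step_output = (torch.randn(8, 1), torch.randn(8, 1))  # e.g. (y_pred, y)

    # star=True: the tuple is unpacked, so F.mse_loss receives two tensors
    loss = F.mse_loss(*step_output)

    # star=False would call loss_fn(step_output), i.e. pass the whole tuple
    # as a single argument, which is a TypeError for this particular loss_fn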
Note

After the function returns:
- model’s gradients (caused by backward) are preserved
- model’s state (training or not) is undefined
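In practice this means: read the gradients right after the call if you need them (e.g. for logging), and set the model’s mode explicitly before the next evaluation. A sketch (grad_norm is a name introduced here for illustration):

    loss_value, step_output = learn(model, optimizer, loss_fn, step, batch)

    # the gradients from the backward pass are still available, e.g. for logging
    grad_norm = sum(
        p.grad.norm() ** 2 for p in model.parameters() if p.grad is not None
    ).item() ** 0.5

    # the training/eval mode is undefined after the call, so set it explicitly
    model.eval()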
Examples
    model = ...
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.MSELoss()

    def step(batch):
        X, y = batch
        return model(X), y

    for epoch in epochs:
        for batch in batches:
            learn(model, optimizer, loss_fn, step, batch, True)
    model = ...
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    def step(batch):
        X, y = batch
        return {'y_pred': model(X), 'y': y}

    loss_fn = lambda out: torch.nn.functional.mse_loss(out['y_pred'], out['y'])

    for epoch in epochs:
        for batch in batches:
            learn(model, optimizer, loss_fn, step, batch)
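For reference, a self-contained version of the first example with a toy model and synthetic data (the model, sizes, and data are made up for demonstration; learn is assumed importable from zero.training, as the page title suggests):

    import torch
    from zero.training import learn

    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.MSELoss()

    # synthetic regression data wrapped in a standard DataLoader
    dataset = torch.utils.data.TensorDataset(torch.randn(64, 10), torch.randn(64, 1))
    batches = torch.utils.data.DataLoader(dataset, batch_size=16)

    def step(batch):
        X, y = batch
        return model(X), y

    for epoch in range(3):
        for batch in batches:
            loss_value, _ = learn(model, optimizer, loss_fn, step, batch, True)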