The layer is used to make skip-connections inside a torch::nn_sequential() network or between several torch::nn_sequential() networks without unnecessary code complication.

memory_layer(id, shared_env, output = FALSE, add = FALSE, verbose = FALSE)

Arguments

id

A unique id to use as a key in the storage list.

shared_env

A shared environment for all instances of memory_layer where the inputs are stored.

output

Boolean variable indicating the mode of the memory layer. If FALSE, the layer stores its input in storage; if TRUE, it extracts the stored value from storage.

add

Boolean variable indicating whether the extracted value is added to the input (TRUE) or concatenated with it (FALSE). Only applicable when output = TRUE.

verbose

Boolean variable indicating whether to print messages to the user.

Details

If output = FALSE, this layer stores its input in shared_env under the key id and then passes the input on to the next layer. This is how the memory layer is used in the masked encoder. If output = TRUE, this layer retrieves the tensor stored under id. This is how the memory layer is used in the decoder. In that case, if add = TRUE, the layer returns the sum of the stored tensor and its input; otherwise it returns their concatenation. If no tensor with the specified id is in storage when a layer with output = TRUE is called, an error is raised.
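
The store and extract behaviour described above can be sketched in base R without torch. This is only an illustration under stated assumptions, not the package's implementation: `memory_step()` is a hypothetical helper, matrices stand in for tensors of shape (batch, features), and concatenation is assumed to happen along the feature dimension with the input placed first.

```r
# Hypothetical base-R stand-in for the layer's forward pass (not the package's code).
# Matrices play the role of tensors with shape (batch, features).
memory_step <- function(input, id, shared_env, output = FALSE, add = FALSE) {
  if (!output) {
    # Store mode: save the input under `id` and pass it through unchanged.
    assign(id, input, envir = shared_env)
    return(input)
  }
  # Extract mode: the key must already have been stored by an earlier call.
  if (!exists(id, envir = shared_env, inherits = FALSE)) {
    stop("No stored value for id '", id, "'.")
  }
  stored <- get(id, envir = shared_env, inherits = FALSE)
  if (add) {
    input + stored        # requires matching shapes
  } else {
    cbind(input, stored)  # assumed: concatenate along the feature dimension
  }
}

env <- new.env()                       # shared storage, passed by reference
x <- matrix(1, nrow = 1, ncol = 4)
h <- memory_step(x, "#1", env)         # stores x and returns it
h2 <- h[, 1:2, drop = FALSE] * 2       # stand-in for a layer that narrows the width to 2

# Widths differ (2 vs. 4), so only concatenation is possible here: result is 1 x 6.
out_cat <- memory_step(h2, "#1", env, output = TRUE, add = FALSE)
# Widths match (4 vs. 4), so the stored value can be added elementwise.
out_add <- memory_step(x * 3, "#1", env, output = TRUE, add = TRUE)
```

Because R environments have reference semantics, every call that receives the same `env` sees the same storage, which is what allows a value stored in one place to be retrieved in another. Calling the helper with output = TRUE for an id that was never stored stops with an error, mirroring the ordering requirement described above.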

Author

Lars Henry Berge Olsen

Examples

if (FALSE) { # \dontrun{
memory_layer_env <- new.env()
net1 <- torch::nn_sequential(
  memory_layer("#1", shared_env = memory_layer_env),
  memory_layer("#0.1", shared_env = memory_layer_env),
  torch::nn_linear(512, 256),
  torch::nn_leaky_relu(),
  # Here add cannot be TRUE because the dimensions mismatch (256 vs. 512).
  memory_layer("#0.1", shared_env = memory_layer_env, output = TRUE, add = FALSE),
  # The dimension after concatenation with the skip-connection is 512 + 256 = 768.
  torch::nn_linear(768, 256)
)
net2 <- torch::nn_sequential(
  torch::nn_linear(512, 512),
  memory_layer("#1", shared_env = memory_layer_env, output = TRUE, add = TRUE),
  ...
)
# Here a and c must be of correct dimensions, e.g., a = torch::torch_ones(1,512).
b <- net1(a)
d <- net2(c) # net2 must be called after net1, otherwise tensor '#1' will not be in storage.
} # }