A torch::nn_module() Representing a Memory Layer

Source: R/approach_vaeac_torch_modules.R (memory_layer.Rd)
The layer is used to make skip-connections inside a torch::nn_sequential() network or between several torch::nn_sequential() networks without unnecessary code complication.
Usage

memory_layer(id, shared_env, output = FALSE, add = FALSE, verbose = FALSE)
Arguments

id
A unique id to use as a key in the storage list.

shared_env
A shared environment for all instances of memory_layer where the inputs are stored.

output
Boolean variable indicating if the memory layer is to store its input in the storage or extract a previously stored tensor from it.

add
Boolean variable indicating if the extracted value is to be added or concatenated to the input. Only applicable when output = TRUE.

verbose
Boolean variable indicating if we want to give printouts to the user.
Details

If output = FALSE, this layer stores its input in the shared_env with the key id and then passes the input on to the next layer; this is the case when the memory layer is used in the masked encoder. If output = TRUE, this layer retrieves the stored tensor from the storage; this is the case when the memory layer is used in the decoder. If add = TRUE, it returns the sum of the stored tensor and the input; otherwise, it returns their concatenation. If the tensor with the specified id is not in storage when a layer with output = TRUE is called, an exception is raised.
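To make these mechanics concrete, here is a minimal sketch of the store/extract round trip. The tensor shapes are chosen purely for illustration, and memory_layer is assumed to be accessible in the current session (it is an internal function, so the triple-colon operator may be needed).

env <- new.env()
store <- memory_layer("#a", shared_env = env)                 # output = FALSE: store the input and pass it through
recall <- memory_layer("#a", shared_env = env, output = TRUE) # output = TRUE: fetch "#a" and concatenate

x <- torch::torch_ones(1, 4)
y <- store(x)                          # y is x unchanged; x is now kept in env under the key "#a"
z <- recall(torch::torch_zeros(1, 4))  # shape (1, 8): the concatenation of the input and the stored x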
Examples

if (FALSE) { # \dontrun{
memory_layer_env <- new.env()
net1 <- torch::nn_sequential(
  memory_layer("#1", shared_env = memory_layer_env),
  memory_layer("#0.1", shared_env = memory_layer_env),
  torch::nn_linear(512, 256),
  torch::nn_leaky_relu(),
  # Here add cannot be TRUE because the dimensions mismatch
  memory_layer("#0.1", shared_env = memory_layer_env, output = TRUE, add = FALSE),
  # The dimension after the concatenation with the skip-connection is 512 + 256 = 768
  torch::nn_linear(768, 256)
)
net2 <- torch::nn_sequential(
  torch::nn_linear(512, 512),
  memory_layer("#1", shared_env = memory_layer_env, output = TRUE, add = TRUE),
  ...
)
# Here a and c must be of the correct dimensions, e.g., a <- torch::torch_ones(1, 512).
b <- net1(a)
d <- net2(c) # net2 must be called after net1, otherwise tensor '#1' will not be in storage.
} # }
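For reference, a fully runnable variant of the same pattern, with smaller, hypothetical layer sizes and net2's elided layers dropped for brevity, could look as follows:

memory_env <- new.env()
encoder <- torch::nn_sequential(
  memory_layer("#skip", shared_env = memory_env), # remember the 8-dimensional input
  torch::nn_linear(8, 4),
  torch::nn_leaky_relu()
)
decoder <- torch::nn_sequential(
  torch::nn_linear(4, 8),
  memory_layer("#skip", shared_env = memory_env, output = TRUE, add = TRUE) # add the stored input
)
x <- torch::torch_randn(1, 8)
h <- encoder(x)   # stores x under "#skip" and returns a (1, 4) tensor
out <- decoder(h) # (1, 8): the decoder output plus the stored x, i.e., a skip-connection

Passing a fresh environment to each model keeps the skip-connection storage local, so two models built this way cannot overwrite each other's stored tensors.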