A torch::nn_module() Representing a Memory Layer
Source: R/approach_vaeac_torch_modules.R
memory_layer.Rd
The layer is used to make skip-connections inside a torch::nn_sequential() network, or between several torch::nn_sequential() networks, without complicating the code.
Arguments
- id
A unique id to use as a key in the storage list.
- shared_env
A shared environment for all instances of memory_layer in which the inputs are stored.
- output
Boolean variable indicating whether the memory layer stores its input in the storage (FALSE) or extracts a stored value from the storage (TRUE).
- add
Boolean variable indicating whether the extracted value is added to the input (TRUE) or concatenated with it (FALSE). Only applicable when output = TRUE.
- verbose
Boolean variable indicating whether to print messages to the user.
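Sharing storage between separate networks works because R environments have reference semantics: an environment is never copied when passed around, so every object holding it sees writes made through any other. The following minimal base-R sketch illustrates that idea with two hypothetical plain-function "networks", `net_a` and `net_b`, standing in for torch modules; it is not shapr code.

```r
# Hypothetical illustration (not shapr code): environments are passed by
# reference, so two "networks" holding the same environment share storage.
shared <- new.env()

# First "network": stores its input under key "#1", then transforms it.
net_a <- function(x, env) {
  assign("#1", x, envir = env)  # store the skip-connection value
  x * 2                         # stand-in for the layers in between
}

# Second "network": adds the value stored by net_a to its own input.
net_b <- function(x, env) {
  x + get("#1", envir = env)    # retrieve the value written by net_a
}

y1 <- net_a(c(1, 2, 3), shared)
y2 <- net_b(y1, shared)
print(y2)
```

Because `shared` itself is modified (not a copy of it), the value written inside `net_a` is visible to `net_b`, which is the same reason all memory_layer instances built with one shared_env can exchange tensors.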
Details
If output = FALSE, this layer stores its input in shared_env under the key id and then passes the input on to the next layer, e.g., when the memory layer is used in the masked encoder. If output = TRUE, this layer retrieves the stored tensor from the storage, e.g., when the memory layer is used in the decoder. If add = TRUE, it returns the sum of the stored tensor and the input; otherwise, it returns their concatenation. If no tensor with the specified id is in storage when a layer with output = TRUE is called, an error is raised.
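The store/extract rules above can be sketched without torch, using plain numeric vectors in place of tensors. `make_memory_layer()` is a hypothetical helper, not part of shapr: it mirrors the documented roles of id, shared_env, output and add, uses `c()` as a stand-in for tensor concatenation, and the order of the concatenated parts (input first, stored value second) is an assumption.

```r
# Hypothetical base-R sketch of the memory-layer semantics (not shapr code).
make_memory_layer <- function(id, shared_env, output = FALSE, add = FALSE) {
  function(input) {
    if (!output) {
      # output = FALSE: store the input under `id` and pass it on unchanged
      assign(id, input, envir = shared_env)
      return(input)
    }
    # output = TRUE: the value must already be in storage, otherwise error
    if (!exists(id, envir = shared_env, inherits = FALSE)) {
      stop(sprintf("Tensor '%s' is not in storage.", id))
    }
    stored <- get(id, envir = shared_env, inherits = FALSE)
    # add = TRUE: element-wise sum; add = FALSE: concatenation (assumed order)
    if (add) input + stored else c(input, stored)
  }
}

env <- new.env()
store       <- make_memory_layer("#1", env)
get_add     <- make_memory_layer("#1", env, output = TRUE, add = TRUE)
get_cat     <- make_memory_layer("#1", env, output = TRUE, add = FALSE)
get_missing <- make_memory_layer("#2", env, output = TRUE)

x <- store(c(1, 2))            # stores c(1, 2) and returns it unchanged
sum_out <- get_add(c(10, 20))  # c(11, 22)
cat_out <- get_cat(c(10, 20))  # c(10, 20, 1, 2)
err <- tryCatch(get_missing(0), error = function(e) conditionMessage(e))
```

The sum only works when the stored value and the input have matching lengths, which is the same constraint that forces add = FALSE in the first network of the example below.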
Examples
if (FALSE) { # \dontrun{
memory_layer_env <- new.env()
net1 <- torch::nn_sequential(
  memory_layer("#1", shared_env = memory_layer_env),
  memory_layer("#0.1", shared_env = memory_layer_env),
  torch::nn_linear(512, 256),
  torch::nn_leaky_relu(),
  # Here add cannot be TRUE because the dimensions mismatch (256 vs. 512)
  memory_layer("#0.1", shared_env = memory_layer_env, output = TRUE, add = FALSE),
  # The dimension after the concatenation with the skip-connection is 512 + 256 = 768
  torch::nn_linear(768, 256)
)
net2 <- torch::nn_sequential(
  torch::nn_linear(512, 512),
  memory_layer("#1", shared_env = memory_layer_env, output = TRUE, add = TRUE)
  # ...
)
# The inputs a and c must have the correct dimensions, e.g., 1 x 512.
a <- torch::torch_ones(1, 512)
c <- torch::torch_ones(1, 512)
b <- net1(a)
d <- net2(c) # net2 must be called after net1, otherwise tensor '#1' will not be in storage.
} # }