Generates predictions from a fitted Bayesian Neural Network (BNN) model.
Usage
# S3 method for class 'bnns'
predict(object, newdata = NULL, ...)
Arguments
- object
An object of class "bnns", typically the result of a call to bnns.default.
- newdata
A matrix or data frame of new input data for which predictions are required. If NULL, predictions are made on the training data used to fit the model.
- ...
Additional arguments (currently not used).
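The `newdata = NULL` fallback can be illustrated with a toy S3 method. The class `toy_fit`, its constructor `new_toy_fit()` and the linear model inside it are hypothetical stand-ins chosen only for this sketch. They are not part of bnns, whose method predicts from the Stan posterior instead.

```r
# Hypothetical toy class "toy_fit", used only to illustrate the S3 pattern;
# it is NOT part of bnns.
new_toy_fit <- function(x, y) {
  # Store the training inputs so that predict() can fall back on them
  structure(list(coef = coef(lm(y ~ x)), x_train = x), class = "toy_fit")
}

# S3 method for class 'toy_fit', mirroring the signature predict(object, newdata = NULL, ...)
predict.toy_fit <- function(object, newdata = NULL, ...) {
  # When newdata is NULL, predict on the stored training inputs
  x <- if (is.null(newdata)) object$x_train else newdata
  unname(object$coef[1] + object$coef[2] * x)
}

fit <- new_toy_fit(x = c(1, 2, 3), y = c(2, 4, 6))
train_pred <- predict(fit)              # predictions for the training inputs
new_pred <- predict(fit, newdata = 10)  # predictions for new inputs
```

Calling `predict(fit)` dispatches to `predict.toy_fit()` because of the object's class attribute, just as `predict(model)` dispatches to the bnns method for a "bnns" object.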
Value
A matrix of predicted values (regression) or an array of predicted probabilities (classification). The first dimension corresponds to the rows of newdata, or to the rows of the training data if newdata is NULL. The second dimension corresponds to the posterior samples, with one entry per draw. When out_act_fn = 3 (multi-class classification), a third dimension indexes the classes.
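A minimal sketch of how the returned object might be summarised, assuming the dimension layout described in this section (observations, then draws, then classes). The arrays below are filled with random placeholder numbers rather than real bnns output; with a fitted model the same calls would be applied to the result of predict().

```r
set.seed(1)
n_obs <- 5; n_draws <- 100; n_class <- 3

# Placeholder for regression output: n_obs x n_draws matrix (NOT real bnns results)
pred_reg <- matrix(rnorm(n_obs * n_draws), nrow = n_obs, ncol = n_draws)

# Posterior mean and 95% credible interval for each observation (row)
post_mean <- rowMeans(pred_reg)
post_ci <- t(apply(pred_reg, 1, quantile, probs = c(0.025, 0.975)))

# Placeholder for out_act_fn = 3 output: n_obs x n_draws x n_class array
raw <- array(runif(n_obs * n_draws * n_class), dim = c(n_obs, n_draws, n_class))
# Normalise so the class probabilities of each (observation, draw) pair sum to 1
pred_cls <- sweep(raw, c(1, 2), apply(raw, c(1, 2), sum), "/")

# Average the probabilities over draws: n_obs x n_class matrix
mean_prob <- apply(pred_cls, c(1, 3), mean)
# Predicted class = column with the highest mean probability
pred_class <- max.col(mean_prob, ties.method = "first")
```

Averaging probabilities over draws before taking the arg-max is one reasonable choice; taking the most frequent class across draws is another.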
Details
This function uses the posterior draws from the Stan model stored in the bnns object to compute predictions for the provided input data, producing one set of predictions per posterior draw.
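Conceptually, predicting from a posterior means running the network forward once for every posterior draw. Below is a minimal sketch for a network with one hidden layer. The parameter names (W1, b1, w2, b2), the tanh hidden activation and the identity output are assumptions made for illustration; they do not reproduce how bnns stores its Stan draws or how it codes act_fn.

```r
set.seed(42)
n_draws <- 4; p <- 2; nodes <- 2
X <- matrix(runif(5 * p), nrow = 5)  # 5 new observations, 2 predictors

# Hypothetical posterior draws (names and shapes are assumptions, not bnns internals)
draws <- lapply(seq_len(n_draws), function(s) list(
  W1 = matrix(rnorm(p * nodes), p, nodes), b1 = rnorm(nodes),
  w2 = rnorm(nodes), b2 = rnorm(1)
))

# Forward pass for one draw: tanh hidden layer (assumed), identity output
forward <- function(X, d) {
  H <- tanh(sweep(X %*% d$W1, 2, d$b1, "+"))
  drop(H %*% d$w2) + d$b2
}

# One column of predictions per posterior draw -> n_obs x n_draws matrix
pred <- sapply(draws, forward, X = X)
dim(pred)  # 5 4
```

The resulting matrix has the same orientation as the regression output described under Value: rows are observations and columns are posterior draws.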
Examples
# \donttest{
# Example usage:
data <- data.frame(x1 = runif(10), x2 = runif(10), y = rnorm(10))
model <- bnns(y ~ -1 + x1 + x2,
data = data, L = 1, nodes = 2, act_fn = 2,
iter = 1e1, warmup = 5, chains = 1
)
#>
#> SAMPLING FOR MODEL 'anon_model' NOW (CHAIN 1).
#> Chain 1:
#> Chain 1: Gradient evaluation took 1.8e-05 seconds
#> Chain 1: 1000 transitions using 10 leapfrog steps per transition would take 0.18 seconds.
#> Chain 1: Adjust your expectations accordingly!
#> Chain 1:
#> Chain 1:
#> Chain 1: WARNING: No variance estimation is
#> Chain 1: performed for num_warmup < 20
#> Chain 1:
#> Chain 1: Iteration: 1 / 10 [ 10%] (Warmup)
#> Chain 1: Iteration: 2 / 10 [ 20%] (Warmup)
#> Chain 1: Iteration: 3 / 10 [ 30%] (Warmup)
#> Chain 1: Iteration: 4 / 10 [ 40%] (Warmup)
#> Chain 1: Iteration: 5 / 10 [ 50%] (Warmup)
#> Chain 1: Iteration: 6 / 10 [ 60%] (Sampling)
#> Chain 1: Iteration: 7 / 10 [ 70%] (Sampling)
#> Chain 1: Iteration: 8 / 10 [ 80%] (Sampling)
#> Chain 1: Iteration: 9 / 10 [ 90%] (Sampling)
#> Chain 1: Iteration: 10 / 10 [100%] (Sampling)
#> Chain 1:
#> Chain 1: Elapsed Time: 0 seconds (Warm-up)
#> Chain 1: 0 seconds (Sampling)
#> Chain 1: 0 seconds (Total)
#> Chain 1:
new_data <- data.frame(x1 = runif(5), x2 = runif(5))
predictions <- predict(model, newdata = new_data)
print(predictions)
#> [,1] [,2] [,3] [,4] [,5]
#> [1,] 0.7574778 1.2453539 1.1224906 0.9577392 0.6537702
#> [2,] 0.4410329 0.8112508 -0.7868859 -0.1728251 1.2776497
#> [3,] 0.9894079 1.0748221 1.5610006 1.2906257 0.6105816
#> [4,] 0.4554538 1.0534120 -0.3056047 0.0328638 0.9640978
#> [5,] 0.3473748 1.2600905 -0.3345410 -0.1466174 0.8629016
# }