
Generates predictions from a fitted Bayesian Neural Network (BNN) model.

Usage

# S3 method for class 'bnns'
predict(object, newdata = NULL, ...)

Arguments

object

An object of class "bnns", typically the result of a call to bnns.default.

newdata

A matrix or data frame of new input data for which predictions are required. If NULL, predictions are made on the training data used to fit the model.

...

Additional arguments (currently not used).

Value

A matrix or array of predicted values (regression) or predicted probabilities (classification). The first dimension corresponds to the rows of newdata, or to the rows of the training data if newdata is NULL. The second dimension corresponds to the posterior samples. When out_act_fn = 3, a third dimension corresponds to the classes.
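For the out_act_fn = 3 case, a common way to turn the three-dimensional array into a single class per row is to average the probabilities over posterior draws and take the most probable class. The sketch below simulates an array of that shape in place of real `predict()` output so it runs without bnns; the names `probs`, `mean_probs` and `pred_class` are illustrative, and it assumes class index k along the third dimension is simply "class k" (how indices map to your factor levels is not stated on this page).

```r
# Stand-in for predict() output when out_act_fn = 3 (simulated, not from bnns):
# dim = c(rows of newdata, posterior draws, classes)
set.seed(7)
n <- 4; S <- 50; K <- 3
raw <- array(rexp(n * S * K), dim = c(n, S, K))

# Normalise so each (row, draw) slice sums to 1 across classes, like probabilities
row_draw_totals <- apply(raw, c(1, 2), sum)          # n x S matrix
probs <- raw / array(row_draw_totals, dim = c(n, S, K))

# Posterior mean probability of each class for each row: n x K matrix
mean_probs <- apply(probs, c(1, 3), mean)

# Predicted class index = class with the highest posterior mean probability
pred_class <- max.col(mean_probs, ties.method = "first")
print(pred_class)
```

Averaging probabilities before taking the argmax keeps the posterior uncertainty in the decision; taking the argmax per draw and then a majority vote is an alternative.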

Details

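This function uses the posterior draws from the Stan model stored in the bnns object to compute predictions for the supplied input data. Each posterior draw yields one prediction per input row, which is why the returned matrix or array has a posterior-sample dimension.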
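To make the "one prediction per draw" idea concrete, here is a generic sketch of posterior-predictive computation for a network with one hidden layer. It is not the bnns implementation: the parameter names (`W1`, `b1`, `w2`, `b2`), the tanh hidden activation and the identity output are all assumptions chosen for illustration, and the "posterior draws" are random numbers rather than Stan output.

```r
set.seed(42)
n <- 4; p <- 2; H <- 3; S <- 6   # rows, inputs, hidden nodes, posterior draws
X <- matrix(runif(n * p), nrow = n, ncol = p)

# Hypothetical posterior draws: each holds W1 (p x H), b1 (H), w2 (H), b2 (scalar)
draws <- lapply(seq_len(S), function(s) list(
  W1 = matrix(rnorm(p * H), p, H),
  b1 = rnorm(H),
  w2 = rnorm(H),
  b2 = rnorm(1)
))

# Forward pass for one draw: tanh hidden layer, identity output (regression)
forward <- function(X, d) {
  hidden <- tanh(sweep(X %*% d$W1, 2, d$b1, "+"))  # n x H, bias added per column
  as.vector(hidden %*% d$w2 + d$b2)                # length-n prediction vector
}

# One column per posterior draw gives an n x S matrix (rows x samples)
pred_mat <- sapply(draws, function(d) forward(X, d))
dim(pred_mat)
```

The resulting layout (rows of the input by posterior samples) matches the shape described under Value for the regression case.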


Examples

# \donttest{
# Example usage:
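library(bnns)
# Toy training data: two predictors and a Gaussian response
data <- data.frame(x1 = runif(10), x2 = runif(10), y = rnorm(10))
# `-1` drops the intercept term. iter = 10 and warmup = 5 keep the example
# fast but are far too short for reliable posterior inference.
model <- bnns(y ~ -1 + x1 + x2,
  data = data, L = 1, nodes = 2, act_fn = 2,
  iter = 1e1, warmup = 5, chains = 1
)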
#> 
#> SAMPLING FOR MODEL 'anon_model' NOW (CHAIN 1).
#> Chain 1: 
#> Chain 1: Gradient evaluation took 1.8e-05 seconds
#> Chain 1: 1000 transitions using 10 leapfrog steps per transition would take 0.18 seconds.
#> Chain 1: Adjust your expectations accordingly!
#> Chain 1: 
#> Chain 1: 
#> Chain 1: WARNING: No variance estimation is
#> Chain 1:          performed for num_warmup < 20
#> Chain 1: 
#> Chain 1: Iteration: 1 / 10 [ 10%]  (Warmup)
#> Chain 1: Iteration: 2 / 10 [ 20%]  (Warmup)
#> Chain 1: Iteration: 3 / 10 [ 30%]  (Warmup)
#> Chain 1: Iteration: 4 / 10 [ 40%]  (Warmup)
#> Chain 1: Iteration: 5 / 10 [ 50%]  (Warmup)
#> Chain 1: Iteration: 6 / 10 [ 60%]  (Sampling)
#> Chain 1: Iteration: 7 / 10 [ 70%]  (Sampling)
#> Chain 1: Iteration: 8 / 10 [ 80%]  (Sampling)
#> Chain 1: Iteration: 9 / 10 [ 90%]  (Sampling)
#> Chain 1: Iteration: 10 / 10 [100%]  (Sampling)
#> Chain 1: 
#> Chain 1:  Elapsed Time: 0 seconds (Warm-up)
#> Chain 1:                0 seconds (Sampling)
#> Chain 1:                0 seconds (Total)
#> Chain 1: 
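# Five new rows with the same predictors used for training
new_data <- data.frame(x1 = runif(5), x2 = runif(5))
# Result: 5 rows (new observations) x 5 columns (the 5 post-warmup draws)
predictions <- predict(model, newdata = new_data)
print(predictions)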
#>           [,1]      [,2]       [,3]       [,4]      [,5]
#> [1,] 0.7574778 1.2453539  1.1224906  0.9577392 0.6537702
#> [2,] 0.4410329 0.8112508 -0.7868859 -0.1728251 1.2776497
#> [3,] 0.9894079 1.0748221  1.5610006  1.2906257 0.6105816
#> [4,] 0.4554538 1.0534120 -0.3056047  0.0328638 0.9640978
#> [5,] 0.3473748 1.2600905 -0.3345410 -0.1466174 0.8629016
# }
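Because each column of the returned matrix is one posterior draw, a per-row summary is usually more useful than the raw draws. The sketch below computes a posterior mean and a 95% equal-tailed credible interval for each row. In a real session you would use `predictions` directly; here a simulated matrix named `preds` stands in so the code runs without fitting a model, and the 95% level is an arbitrary choice.

```r
# Stand-in for `predictions`: rows = observations, columns = posterior draws.
# Simulated here so the sketch runs without bnns.
set.seed(1)
preds <- matrix(rnorm(5 * 200, mean = 1), nrow = 5, ncol = 200)

# Posterior mean prediction for each row
post_mean <- rowMeans(preds)

# 95% equal-tailed credible interval for each row (lower and upper quantile)
post_ci <- t(apply(preds, 1, quantile, probs = c(0.025, 0.975)))

summary_df <- data.frame(
  mean  = post_mean,
  lower = post_ci[, 1],
  upper = post_ci[, 2]
)
print(summary_df)
```

With only 5 posterior draws, as in the example above, such intervals are very rough; more iterations give more stable summaries.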