Skip to contents

Simple filter emulating caret::findCorrelation(exact = FALSE).

This gives each feature a score between 0 and 1 that is one minus the cutoff value for which it is excluded when using caret::findCorrelation(). The negative is used because caret::findCorrelation() excludes everything above a cutoff, while filters exclude everything below a cutoff. Here the filter scores are shifted by +1 to get positive values, in line with the way other filters work.

Consequently, caret::findCorrelation(cutoff = 0.9) lists the same features that are excluded with FilterFindCorrelation at a score of 0.1 (= 1 - 0.9).

Super class

mlr3filters::Filter -> FilterFindCorrelation

Methods

Inherited methods


Method new()

Create a FilterFindCorrelation object.

Usage


Method clone()

The objects of this class are cloneable with this method.

Usage

FilterFindCorrelation$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

Examples

# Pearson (default)
task = mlr3::tsk("mtcars")
filter = flt("find_correlation")
filter$calculate(task)
as.data.table(filter)
#>     feature      score
#>      <char>      <num>
#>  1:    carb 1.00000000
#>  2:    gear 0.72592716
#>  3:    qsec 0.34375077
#>  4:      wt 0.28755935
#>  5:    drat 0.28728887
#>  6:      vs 0.25546456
#>  7:      hp 0.25018753
#>  8:      am 0.20594124
#>  9:    disp 0.11202008
#> 10:     cyl 0.09796713

## Spearman
filter = flt("find_correlation", method = "spearman")
filter$calculate(task)
as.data.table(filter)
#>     feature      score
#>      <char>      <num>
#>  1:    qsec 1.00000000
#>  2:      am 0.79666789
#>  3:    carb 0.34128186
#>  4:    drat 0.25518383
#>  5:      hp 0.24840661
#>  6:      wt 0.22532327
#>  7:      vs 0.20842852
#>  8:    gear 0.19231200
#>  9:    disp 0.10229356
#> 10:     cyl 0.07234842

if (mlr3misc::require_namespaces(c("mlr3pipelines", "rpart"), quietly = TRUE)) {
  library("mlr3pipelines")
  task = mlr3::tsk("spam")

  # Note: the value of `filter.frac` is chosen arbitrarily here and should be tuned.

  graph = po("filter", filter = flt("find_correlation"), filter.frac = 0.5) %>>%
    po("learner", mlr3::lrn("classif.rpart"))

  graph$train(task)
}
#> $classif.rpart.output
#> NULL
#>