Bases: ClassifierMixin, NeuralNet
Implementation of the NeuralNet base class. See sklx.net.NeuralNet for detailed documentation.
Source code in sklx/classifier.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class NeuralNetworkClassifier(ClassifierMixin, NeuralNet):
    """
    Classifier implementation of the NeuralNet base class.

    See `sklx.net.NeuralNet` for detailed documentation.

    Args:
        module: The network (an ``nn.Module``) to train; forwarded to
            ``NeuralNet``.
        max_epochs: Number of training epochs; forwarded to ``NeuralNet``.
        lr: Learning rate; forwarded to ``NeuralNet``.
        criterion: Loss function; forwarded to ``NeuralNet``.
        classes: Optional sequence of class labels, exposed as a NumPy
            array through ``classes_``.
    """

    # Class-level defaults. ``module``, ``max_epochs`` and ``lr`` are passed
    # to ``NeuralNet.__init__``, which presumably stores them per instance
    # (NOTE(review): confirm in sklx.net). ``batch_size`` and ``optimizer``
    # are not set by ``__init__`` here, so instances see these values unless
    # the base class assigns them.
    module = None
    max_epochs = 10
    lr = 0.1
    batch_size = 10
    optimizer = optimizers.SGD

    def __init__(
        self,
        module: nn.Module,
        max_epochs: int,
        lr: float,
        criterion: Callable,
        classes=None,
    ) -> None:
        super().__init__(
            module=module, criterion=criterion, lr=lr, max_epochs=max_epochs
        )
        # Stored unmodified under the constructor argument's name, following
        # the scikit-learn estimator convention for constructor parameters.
        self.classes = classes

    @property
    def classes_(self) -> np.ndarray:
        """Class labels known to this classifier, as a NumPy array.

        Returns:
            ``np.array(self.classes)``.

        Raises:
            AttributeError: If ``classes`` is ``None``. ``np.array(None)``
                would otherwise yield a 0-d object array that cannot be
                iterated or measured with ``len``. ``AttributeError`` is
                also what ``hasattr`` expects for a missing attribute.
        """
        if self.classes is None:
            raise AttributeError(
                "classes_ is not available: no `classes` were provided "
                "to NeuralNetworkClassifier."
            )
        return np.array(self.classes)
|