
15 Loss Functions

Loss functions are curried and have the following type:
(target-fn? . -> . expectant-fn?)
where a target-fn? accepts a tensor?, then a theta?, and returns a tensor?. An expectant-fn? accepts xs and ys, two tensors representing a subset of the dataset, and returns an objective-fn?.

These are defined as follows:
  • target-fn? : (-> tensor? (-> theta? tensor?))

  • expectant-fn? : (-> tensor? tensor? objective-fn?)

  • objective-fn? : (-> theta? tensor?)

The tensor returned by an objective-fn? must have rank 1, and its tlen must equal the number of elements in xs.
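To make the three-layer currying concrete, here is a minimal Python sketch of the same structure, using nested lists as rank-2 tensors. The make_loss helper is illustrative only, not part of malt:

```python
def make_loss(per_example_loss):
    """Curried loss builder mirroring malt's shape:
    loss : target-fn? -> expectant-fn?
    expectant : xs, ys -> objective-fn?
    objective : theta -> rank-1 list of per-example losses
    """
    def loss(target):
        def expectant(xs, ys):
            def objective(theta):
                pred_ys = target(xs)(theta)
                # One loss value per element of xs, matching the rank-1 result.
                return [per_example_loss(p, y) for p, y in zip(pred_ys, ys)]
            return objective
        return expectant
    return loss
```

With a squared-error per-example loss and a toy scaling target, calling the result with a theta yields one loss per example, as the rank-1 requirement above describes.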

The following loss functions are available in malt.

procedure

(((l2-loss target) xs ys) θ)  tensor?

  target : (-> tensor? (-> theta? tensor?))
  xs : tensor?
  ys : tensor?
  θ : theta?
Implements the sum-of-squared-errors (SSE) loss function.
(let ((pred-ys ((target xs) θ)))
  (sum
    (sqr
      (- ys pred-ys))))
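The same arithmetic, sketched in Python for clarity, with plain lists standing in for tensors (helper names are illustrative, not malt's API; malt's sum and - broadcast over the batch automatically):

```python
def l2_per_example(pred_row, y_row):
    # Sum of squared differences for one example: sum((ys - pred-ys)^2).
    return sum((y - p) ** 2 for p, y in zip(pred_row, y_row))

def l2_loss_batch(pred_ys, ys):
    # One SSE value per example, matching the rank-1 objective result.
    return [l2_per_example(p, y) for p, y in zip(pred_ys, ys)]
```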

procedure

(((cross-entropy-loss target) xs ys) θ)  tensor?

  target : (-> tensor? (-> theta? tensor?))
  xs : tensor?
  ys : tensor?
  θ : theta?
Implements the cross-entropy loss function.
(let ((pred-ys ((target xs) θ))
      (num-classes (ref (reverse (shape ys)) 0)))
  (* -1
    (/ (dot-product ys (log pred-ys))
       num-classes)))
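The per-example computation can be sketched in Python as follows, with a plain list standing in for one row of ys (the helper name is illustrative, not malt's API; num_classes corresponds to the last element of the shape of ys, as in the Racket code above):

```python
import math

def cross_entropy_per_example(pred_row, y_row):
    # -(ys . log(pred-ys)) / num-classes for one example.
    num_classes = len(y_row)
    dot = sum(y * math.log(p) for p, y in zip(pred_row, y_row))
    return -dot / num_classes
```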

procedure

(((kl-loss target) xs ys) θ)  tensor?

  target : (-> tensor? (-> theta? tensor?))
  xs : tensor?
  ys : tensor?
  θ : theta?
Implements the KL-divergence loss function.
(let ((pred-ys ((target xs) θ)))
  (sum (* pred-ys (log (/ pred-ys ys)))))
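For one example, the divergence term can be sketched in Python like this (plain lists stand in for rank-1 tensors; the helper name is illustrative, not malt's API). Note the direction: per the formula above, this measures the divergence of pred-ys from ys:

```python
import math

def kl_per_example(pred_row, y_row):
    # sum over classes of pred-ys * log(pred-ys / ys).
    return sum(p * math.log(p / y) for p, y in zip(pred_row, y_row))
```

The result is zero when the two distributions coincide and positive otherwise.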