5 Extended Functions

Base functions are commonly understood functions that only work on tensors of a specific rank known as the base rank. Base functions can be extended to work with tensors of any rank higher than the base rank.

Additionally, all the base and extended functions are differentiable (as described in Overview). Here we define a differentiable function as a Racket function that can be used with the automatic differentiation functions provided by Malt to produce valid gradients.

The following types, as described in Overview, are assumed:
  • primitive-1? A differentiable function of one argument that is constructed from the invocation of prim1.

  • primitive-2? A differentiable function of two arguments that is constructed from the invocation of prim2.

5.1 Unary function extension rules

A unary base function f that accepts a tensor of base rank m can be extended to accept a tensor t of rank m or higher using the following rules.
  • If t is of rank m, the result is the tensor (f t).

  • Else for each element of t, say te, invoke the extended function on te and assemble the results into a tensor, in the order of the elements of t.

Unary functions are extended using ext1, ext1-ρ, or ext1-∇.
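The two rules above can be sketched in plain Racket. This is a model only, not Malt's implementation: it assumes scalars are Racket numbers and higher-rank tensors are nested vectors, and the names rank-of, ext1-model, sum-1-model, and sum-model are hypothetical helpers introduced for illustration.

```racket
#lang racket
;; Model only: a scalar is a number (rank 0); a tensor of higher rank is a
;; non-empty vector whose elements are tensors of one rank lower.
(define (rank-of t)
  (if (vector? t)
      (add1 (rank-of (vector-ref t 0)))
      0))

;; ext1-model is a hypothetical stand-in for ext1, not Malt's definition.
;; It assumes the argument has rank m or higher.
(define (ext1-model f m)
  (define (extended t)
    (if (= (rank-of t) m)
        (f t)                        ; rule 1: t is at the base rank
        (vector-map extended t)))    ; rule 2: recur on each element, in order
  extended)

;; A base function of base rank 1: the sum of a vector of numbers.
(define (sum-1-model v)
  (for/sum ([x (in-vector v)]) x))

(define sum-model (ext1-model sum-1-model 1))

(sum-model (vector 1 2 3))                       ; evaluates to 6
(sum-model (vector (vector 1 2) (vector 3 4)))   ; evaluates to (vector 3 7)
```

In the second call the argument has rank 2, so the extended function descends one level and applies the base function to each rank-1 row.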

5.2 Binary function extension rules

A binary base function f that accepts tensors of base ranks m and n can be extended to tensors t and u of ranks higher than m and n, respectively, by recursively descending into t and u with the following rules.
  • If t is of rank m and u is of rank n, the result is the tensor (f t u).

  • Else if t is of rank m, for each element of u, say ue, invoke the extended function on t and ue, and assemble the results into a tensor, in the order of the elements of u.

  • Else if u is of rank n, for each element of t, say te, invoke the extended function on te and u, and assemble the results into a tensor, in the order of the elements of t.

  • Else if u and t are of equal length, for each pair of corresponding elements, say te from t and ue from u, invoke the extended function on te and ue, and assemble the results into a tensor, in the order of the elements of t and u.

  • Else if the rank of u is higher than the rank of t, for each element of u, say ue, invoke the extended function on t and ue, and assemble the results into a tensor, in the order of the elements of u.

  • Else if the rank of t is higher than the rank of u, for each element of t, say te, invoke the extended function on te and u, and assemble the results into a tensor, in the order of the elements of t.

Binary functions are extended using ext2, ext2-ρ, or ext2-∇.
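The six rules above translate almost directly into a cond. The following is a plain-Racket sketch under the same modelling assumption as before (numbers for scalars, nested vectors for higher ranks); rank-of, ext2-model, and plus-model are hypothetical names, and the final error clause is an addition for the case the rules leave unspecified.

```racket
#lang racket
;; Model only: numbers are rank-0 tensors, nested vectors are higher ranks.
(define (rank-of t)
  (if (vector? t)
      (add1 (rank-of (vector-ref t 0)))
      0))

;; ext2-model is a hypothetical stand-in for ext2, not Malt's definition.
(define (ext2-model f m n)
  (define (extended t u)
    (define rt (rank-of t))
    (define ru (rank-of u))
    (cond
      [(and (= rt m) (= ru n)) (f t u)]                    ; both at base rank
      [(= rt m) (vector-map (λ (ue) (extended t ue)) u)]   ; descend into u
      [(= ru n) (vector-map (λ (te) (extended te u)) t)]   ; descend into t
      [(= (vector-length t) (vector-length u))
       (vector-map extended t u)]                          ; descend into both
      [(> ru rt) (vector-map (λ (ue) (extended t ue)) u)]  ; u has higher rank
      [(> rt ru) (vector-map (λ (te) (extended te u)) t)]  ; t has higher rank
      [else (error 'ext2-model "no extension rule applies")]))
  extended)

;; Scalar addition (base ranks 0 and 0) extended to higher ranks.
(define plus-model (ext2-model + 0 0))

(plus-model 1 2)                            ; evaluates to 3
(plus-model 10 (vector 1 2))                ; evaluates to (vector 11 12)
(plus-model (vector 1 2) (vector 10 20))    ; evaluates to (vector 11 22)
;; Lengths differ (2 and 3) and u has the higher rank, so only u is descended:
(plus-model (vector 1 2)
            (vector (vector 10 20) (vector 30 40) (vector 50 60)))
```

The last call pairs the whole of t with each row of u, giving a tensor of three rows: (11 22), (31 42), and (51 62).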

procedure

(ext1 prim base-rank)  primitive-1?

  prim : primitive-1?
  base-rank : natural?
Extends prim to operate on tensors of rank higher than base-rank. The returned function is also a primitive-1?.

procedure

(ext2 prim base-rank-1 base-rank-2)  primitive-2?

  prim : primitive-2?
  base-rank-1 : natural?
  base-rank-2 : natural?
Extends prim to operate on tensors of rank higher than base-rank-1 in the first argument and higher than base-rank-2 in the second argument. The returned function is also a primitive-2?.

5.3 Primitives in learner and nested-tensors

The following functions are available when the default tensor representation in Malt is learner or nested-tensors.

procedure

(prim1 ρ-fn ∇-fn)  primitive-1?

  ρ-fn : (-> tensor? tensor?)
  ∇-fn : (-> tensor? tensor? tensor?)
Constructs a differentiable function (known as a primitive) of one tensor argument that invokes ρ-fn to compute the result of the application of the primitive, and uses ∇-fn to find the gradient of the result with respect to the argument provided to the primitive.
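As a rough illustration of the pair of functions that prim1 expects, here is a plain-Racket sketch for squaring a scalar. It does not call Malt. The contract above only says that ∇-fn takes two tensors, so the reading used here, that the first is the argument and the second is the gradient arriving from the result, is an assumption; square-ρ, square-∇, and numeric-derivative are hypothetical names.

```racket
#lang racket
;; Hypothetical ρ-fn: computes the result of the primitive.
(define (square-ρ a)
  (* a a))

;; Hypothetical ∇-fn. ASSUMPTION: a is the argument to the primitive and z is
;; the gradient flowing back from the result, combined by the chain rule.
(define (square-∇ a z)
  (* z 2 a))

;; Central finite difference, used only to sanity-check square-∇.
(define (numeric-derivative f a [h 1e-6])
  (/ (- (f (+ a h)) (f (- a h)))
     (* 2 h)))

(square-ρ 3.0)                        ; evaluates to 9.0
(square-∇ 3.0 1.0)                    ; evaluates to 6.0
(numeric-derivative square-ρ 3.0)     ; approximately 6.0
```

Under that assumption, the two functions would be passed together, ρ-fn first and ∇-fn second, to construct the primitive.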

procedure

(prim2 ρ-fn ∇-fn)  primitive-2?

  ρ-fn : (-> tensor? tensor? tensor?)
  ∇-fn : (-> tensor? tensor? tensor? tensor?)
Constructs a differentiable function (known as a primitive) of two tensor arguments that invokes ρ-fn to compute the result of the application of the primitive, and uses ∇-fn to find the gradient of the result with respect to the two arguments provided to the primitive.

5.4 Primitives in flat-tensors

The following functions are available when the default tensor representation in Malt is flat-tensors.

procedure

(prim1 ρ-fn ∇-fn shape-fn)  primitive-1?

  ρ-fn : (-> tensor? tensor?)
  ∇-fn : (-> tensor? tensor? tensor?)
  shape-fn : (-> shape? shape?)
Constructs a differentiable function (known as a primitive) of one tensor argument that invokes ρ-fn to compute the result of the application of the primitive, and uses ∇-fn to find the gradient of the result with respect to the argument provided to the primitive. The third argument, shape-fn, when invoked with the shape of the argument to the primitive, provides the shape of the result from the invocation of ρ-fn on that argument.

procedure

(prim2 ρ-fn ∇-fn shape-fn)  primitive-2?

  ρ-fn : (-> tensor? tensor? tensor?)
  ∇-fn : (-> tensor? tensor? tensor? tensor?)
  shape-fn : (-> shape? shape? shape?)
Constructs a differentiable function (known as a primitive) of two tensor arguments that invokes ρ-fn to compute the result of the application of the primitive, and uses ∇-fn to find the gradient of the result with respect to the arguments provided to the primitive. The third argument, shape-fn, when invoked with the shapes of the two arguments to the primitive, provides the shape of the result from the invocation of ρ-fn on those two arguments.
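The role of shape-fn can be illustrated with a small plain-Racket model: shape-fn predicts, from argument shapes alone, the shape that ρ-fn will produce. The sketch assumes shapes are lists of natural numbers and models tensors as nested vectors, which is not how flat-tensors stores them; shape-of, sum-ρ, sum-shape, concat-ρ, and concat-shape are hypothetical names.

```racket
#lang racket
;; Model only: the shape of a number is '(); the shape of a vector is its
;; length followed by the shape of its first element.
(define (shape-of t)
  (if (vector? t)
      (cons (vector-length t) (shape-of (vector-ref t 0)))
      '()))

;; Unary example: summing a rank-1 tensor yields a scalar.
(define (sum-ρ v)
  (for/sum ([x (in-vector v)]) x))
(define (sum-shape s)
  (cdr s))                             ; (list n) becomes '()

;; Binary example: concatenating two rank-1 tensors adds their lengths.
(define (concat-ρ v w)
  (vector-append v w))
(define (concat-shape s t)
  (list (+ (car s) (car t))))          ; (list n) and (list k) become (list (+ n k))

(define v (vector 1 2))
(define w (vector 3 4 5))

;; In each pair the predicted shape agrees with the shape of the actual result.
(sum-shape (shape-of v))                       ; evaluates to '()
(shape-of (sum-ρ v))                           ; evaluates to '()
(concat-shape (shape-of v) (shape-of w))       ; evaluates to '(5)
(shape-of (concat-ρ v w))                      ; evaluates to '(5)
```

Knowing the result shape ahead of time is what lets a flat representation allocate storage for the result before ρ-fn runs, which is presumably why the flat-tensors versions of prim1 and prim2 take this extra argument.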

5.5 Extension in flat-tensors and nested-tensors

The following functions are available when the default tensor representation in Malt is flat-tensors or nested-tensors. They can be used to provide new tensor operations that cannot be constructed using ext1 and ext2. The contracts below are a guideline; the actual types of the ρ-fn and ∇-fn arguments vary by tensor representation. See the Malt source code for further examples.

procedure

(ext1-ρ ρ-fn base-rank)  (-> tensor? tensor?)

  ρ-fn : (-> tensor? tensor?)
  base-rank : natural?
Extends ρ-fn to operate on tensors of rank higher than base-rank. The returned function is not differentiable. ext1-ρ can be used to construct functions that can be the first argument to prim1.

procedure

(ext1-∇ ∇-fn base-rank)  (-> tensor? tensor? tensor?)

  ∇-fn : (-> tensor? tensor? tensor?)
  base-rank : natural?
Extends ∇-fn to operate on tensors of rank higher than base-rank. The returned function is not differentiable. ext1-∇ can be used to construct functions that can be the second argument to prim1.

procedure

(ext2-ρ ρ-fn base-rank-1 base-rank-2)  (-> tensor? tensor? tensor?)

  ρ-fn : (-> tensor? tensor? tensor?)
  base-rank-1 : natural?
  base-rank-2 : natural?
Extends ρ-fn to operate on tensors of rank higher than base-rank-1 in the first argument and higher than base-rank-2 in the second argument. The returned function is not differentiable. ext2-ρ can be used to construct functions that can be the first argument to prim2.

procedure

(ext2-∇ ∇-fn base-rank-1 base-rank-2)  (-> tensor? tensor? tensor? tensor?)

  ∇-fn : (-> tensor? tensor? tensor? tensor?)
  base-rank-1 : natural?
  base-rank-2 : natural?
Extends ∇-fn to operate on tensors of rank higher than base-rank-1 in the first argument and higher than base-rank-2 in the second argument. The returned function is not differentiable. ext2-∇ can be used to construct functions that can be the second argument to prim2.