# Source code for mixturemapping.layers.linearmapping

## Copyright 2020-2024 Viktor Krueckl. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Dictionary key under which the mixture weights are stored.
_B='weights'
# Dictionary key under which the mixture component means are stored.
_A='means'
import tensorflow as _tf,tensorflow_probability as _tfp
_tfd=_tfp.distributions
from._mapping import _M
from..distributions import regularizeCovMatrix as _r
class LinearMapping(_M):
    """Mapping layer that applies a learned affine map to mixture components.

    The layer keeps a trainable mean and log-scale for both the kernel and
    the bias and wraps them in Normal distributions (the scale is the
    exponential of the ``*_std`` weight). ``call`` maps the incoming
    component means through the kernel and derives the matching output
    covariances.

    NOTE(review): ``self.output_dim``, ``self.sampling`` and
    ``self._regularize_cov_epsilon`` are read here but not defined in this
    class; presumably they are provided by the base class. Confirm there.
    """

    def build(self, input_shapes):
        """Create the trainable parameters and their Normal distributions.

        Args:
            input_shapes: Mapping that must contain the keys ``"means"`` and
                ``"weights"``. Entry 1 of ``input_shapes["means"]`` is stored
                as ``mix_dim`` and entry 2 as ``input_dim``.

        Raises:
            Exception: If ``"means"`` or ``"weights"`` is missing.
        """
        has_required_keys = "means" in input_shapes and "weights" in input_shapes
        if not has_required_keys:
            raise Exception('means and weights are needed to construct the mapping layer!')

        means_shape = input_shapes["means"]
        self.mix_dim = means_shape[1]
        self.input_dim = means_shape[2]

        def new_parameter(name, shape):
            # All four parameters share the same initializer and are trainable.
            return self.add_weight(
                name=name,
                shape=shape,
                initializer="normal",
                trainable=True,
            )

        # Creation order is kept fixed because it determines the order of the
        # layer's weight list.
        self.kernel_mean = new_parameter("kernel_mean", (self.input_dim, self.output_dim))
        self.kernel_std = new_parameter("kernel_std", (self.input_dim, self.output_dim))
        self.bias_mean = new_parameter("bias_mean", (self.output_dim,))
        self.bias_std = new_parameter("bias_std", (self.output_dim,))

        # The *_std weights are exponentiated, so the Normal scales are
        # always positive.
        self.kernel = _tfd.Normal(self.kernel_mean, _tf.exp(self.kernel_std))
        self.bias = _tfd.Normal(self.bias_mean, _tf.exp(self.bias_std))

        super().build(input_shapes["means"])

    @_tf.function(autograph=False)
    def call(self, x, **kwargs):
        """Map mixture means and covariances through the linear kernel.

        Args:
            x: Mapping with the entries ``"means"`` and ``"weights"`` plus
                either ``"stdDevs"`` or ``"covariances"``. Optional entries
                ``"covA"`` and ``"covB"`` are added to the mapped covariance;
                a rank-3 entry gets an extra axis inserted at position 1
                before the addition.
                (assumes means are (batch, mix_dim, input_dim); TODO confirm)
            **kwargs: Accepted but not used.

        Returns:
            dict with ``"means"`` (mapped means), ``"covariances"`` (mapped
            and optionally regularized covariance) and ``"weights"`` (passed
            through unchanged from ``x``).

        Side effects:
            Stores the intermediate results on the instance as
            ``self.mappingCov`` and ``self.covMatrix``.
        """
        means = x["means"]
        batch_size = _tf.shape(means)[0]

        def sampled_means():
            # One kernel/bias draw per batch entry.
            kernel_draws = self.kernel.sample(batch_size)
            bias_draws = self.bias.sample(batch_size)
            return _tf.matmul(means, kernel_draws) + _tf.expand_dims(bias_draws, 1)

        def deterministic_means():
            return _tf.matmul(means, self.kernel_mean) + self.bias_mean

        mapped_means = _tf.cond(self.sampling, sampled_means, deterministic_means)

        # The covariance path uses the kernel mean with gradients blocked;
        # two leading axes are added so it broadcasts over the first two
        # axes of the inputs.
        frozen_kernel = _tf.stop_gradient(self.kernel_mean)
        kernel_4d = _tf.expand_dims(_tf.expand_dims(frozen_kernel, 0), 0)

        if "stdDevs" in x:
            # Turn the standard deviations into a column and scale the kernel
            # rows element-wise.
            std_column = _tf.transpose(
                _tf.expand_dims(x["stdDevs"], 2), [0, 1, 3, 2]
            )
            scaled_kernel = _tf.multiply(std_column, kernel_4d)
        else:
            # Full covariances: use the matrix square root times the kernel.
            cov_root = _tf.linalg.sqrtm(x["covariances"])
            scaled_kernel = _tf.matmul(cov_root, kernel_4d)

        # E^T E with E = scaled_kernel.
        self.mappingCov = _tf.matmul(
            _tf.transpose(scaled_kernel, [0, 1, 3, 2]), scaled_kernel
        )

        total_cov = self.mappingCov
        for extra_key in ("covA", "covB"):
            if extra_key not in x:
                continue
            extra_cov = x[extra_key]
            if len(extra_cov.shape.dims) == 3:
                extra_cov = _tf.expand_dims(extra_cov, 1)
            total_cov = total_cov + extra_cov

        epsilon = self._regularize_cov_epsilon
        self.covMatrix = _r(total_cov, epsilon) if epsilon else total_cov

        return {
            "means": mapped_means,
            "covariances": self.covMatrix,
            "weights": x["weights"],
        }