// Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using UnityEngine;
namespace Meta.XR.Movement.FaceTracking.Samples
{
///
/// Maintains input and output signals; has
/// which allows producing output signals from supplied inputs.
///
/// <summary>
/// Maintains input and output signals; has <see cref="Eval(float[], float[])"/>
/// which allows producing output signals from supplied inputs.
/// </summary>
/// <remarks>
/// The configuration is a list of rules. Each rule has drivers (input signals, each with a peak
/// value) and targets (output signals, each with a weight). The constructor builds a delta matrix
/// with the following property: feeding exactly one rule's peak values into
/// <see cref="Eval(float[], float[])"/> reproduces that rule's targets, before clamping. This
/// assumes <c>Matrix.Mult</c> multiplies the activation row vector by the delta matrix.
/// <see cref="Eval(float[], float[])"/> reuses a per-instance scratch buffer, so it must not be
/// called concurrently on the same instance.
/// </remarks>
public class Retargeter
{
    /// <summary>
    /// Names of the input signals, in the order <see cref="Eval(float[], float[])"/> expects them.
    /// </summary>
    public string[] InputSignals { get; private set; }

    /// <summary>
    /// Names of the output signals, in the order <see cref="Eval(float[], float[])"/> writes them.
    /// </summary>
    public string[] OutputSignals { get; private set; }

    // Tolerance within which a signal counts as sitting exactly on a driver's peak.
    private const float Eps = 1e-4f;

    /// <summary>
    /// A signal referenced by index, paired with a weight. For drivers the weight is the peak
    /// signal value; for targets it is the value the rule should produce for that output.
    /// </summary>
    private struct Item
    {
        /// <summary>Index of the signal in the input (drivers) or output (targets) list.</summary>
        public int Index;

        /// <summary>Peak value (drivers) or target value (targets).</summary>
        public float Weight;

        /// <summary>
        /// Evaluates a piecewise-linear tent: 1 at <see cref="Weight"/>, rising linearly from 0 at
        /// signal 0 and falling linearly to 0 at signal 1.
        /// </summary>
        /// <param name="signal">Signal value.</param>
        /// <returns>
        /// 1 when <paramref name="signal"/> is within <see cref="Eps"/> of the peak. Otherwise
        /// <c>signal / Weight</c> at or below the peak and <c>(1 - signal) / (1 - Weight)</c> above it.
        /// When that formula would divide by zero (peak exactly 0 or 1 with the signal outside
        /// [0, 1]), returns 0.
        /// </returns>
        public float Eval(float signal)
        {
            if (Mathf.Abs(signal - Weight) < Eps)
            {
                return 1.0f;
            }

            if (signal <= Weight)
            {
                // A zero peak reaches this branch only for signals at or below -Eps;
                // dividing would produce -Infinity.
                return Weight != 0f ? signal / Weight : 0f;
            }

            // A peak of exactly 1 (every V1 driver) reaches this branch for any signal above 1 + Eps.
            // Dividing would produce -Infinity, which turns into NaN once multiplied by a zero weight or delta.
            var falloff = 1f - Weight;
            return falloff != 0f ? (1f - signal) / falloff : 0f;
        }
    }

    /// <summary>
    /// One retargeting rule: a set of drivers (input peaks) and the targets (output weights)
    /// it should produce when the inputs sit at those peaks.
    /// </summary>
    private struct Rule
    {
        /// <summary>Input signals that activate this rule, each with its peak value.</summary>
        public Item[] Drivers;

        /// <summary>Output signals this rule drives, each with its target value.</summary>
        public Item[] Targets;

        /// <summary>
        /// Writes this rule's driver peaks into <paramref name="signals"/> and its target weights
        /// into <paramref name="targets"/>. Entries not referenced by the rule are left untouched.
        /// </summary>
        /// <param name="signals">Input-signal row to fill, indexed by driver index.</param>
        /// <param name="targets">Output-signal row to fill, indexed by target index.</param>
        public void Peak(float[] signals, float[] targets)
        {
            foreach (var driver in Drivers)
            {
                signals[driver.Index] = driver.Weight;
            }

            foreach (var target in Targets)
            {
                targets[target.Index] = target.Weight;
            }
        }

        /// <summary>
        /// Computes this rule's activation from <paramref name="signals"/>. The result is the
        /// product, over all drivers, of the driver's tent value times the driver's weight.
        /// </summary>
        /// <param name="signals">Input signal values, indexed by driver index.</param>
        /// <returns>Computed activation.</returns>
        public float Eval(float[] signals)
        {
            var weight = 1.0f;
            foreach (var d in Drivers)
            {
                weight *= d.Eval(signals[d.Index]) * d.Weight;
            }

            return weight;
        }

        /// <summary>
        /// Same computation as the <c>float[]</c> overload, for signal rows exposed as a list.
        /// </summary>
        /// <param name="signals">Input signal values, indexed by driver index.</param>
        /// <returns>Computed activation.</returns>
        public float Eval(IList<float> signals)
        {
            var weight = 1.0f;
            foreach (var d in Drivers)
            {
                weight *= d.Eval(signals[d.Index]) * d.Weight;
            }

            return weight;
        }
    }

    // Rules in load order; row i of _deltas corresponds to _rules[i].
    private Rule[] _rules;

    // (rule count) x (output count) matrix mapping rule activations to outputs.
    private Matrix _deltas;

    // Scratch buffer reused by Eval for per-rule activations; makes Eval non-reentrant.
    private float[] _activations;

    /// <summary>
    /// Returns the index registered for <paramref name="name"/>. If the name is new, assigns it
    /// the next dense index (the current count) first.
    /// </summary>
    private static int GetOrAddIndex(Dictionary<string, int> indices, string name)
    {
        if (!indices.TryGetValue(name, out var index))
        {
            index = indices.Count;
            indices.Add(name, index);
        }

        return index;
    }

    /// <summary>
    /// Converts a name-to-weight map into items. Each name is resolved (and registered if new)
    /// in <paramref name="indices"/>.
    /// </summary>
    /// <param name="weights">Signal name to weight pairs.</param>
    /// <param name="indices">Name-to-index registry; extended with any names not yet present.</param>
    /// <returns>One item per entry, in enumeration order.</returns>
    private static List<Item> LoadWeights(IEnumerable<KeyValuePair<string, float>> weights, Dictionary<string, int> indices)
    {
        var items = new List<Item>();
        foreach (var pair in weights)
        {
            items.Add(new Item { Index = GetOrAddIndex(indices, pair.Key), Weight = pair.Value });
        }

        return items;
    }

    /// <summary>
    /// Loads a V1 setup (object root). Each entry becomes a rule with a single driver (the
    /// entry's signal, peak 1) and the entry's rig-driver weights as targets.
    /// Entries without targets are skipped. Their signal has already been registered at that
    /// point, so it still appears in <see cref="InputSignals"/>.
    /// </summary>
    private static List<Rule> LoadV1(string json, Dictionary<string, int> signals, Dictionary<string, int> rigDrivers)
    {
        var mapping = JSONRigParser.DeserializeV1Mapping(json);
        var rules = new List<Rule>();
        foreach (var s in mapping)
        {
            var signalIndex = GetOrAddIndex(signals, s.Key);
            var targets = LoadWeights(s.Value, rigDrivers);
            if (targets.Count == 0)
            {
                continue;
            }

            rules.Add(new Rule
            {
                Drivers = new[] { new Item { Index = signalIndex, Weight = 1f } },
                Targets = targets.ToArray(),
            });
        }

        return rules;
    }

    /// <summary>
    /// Loads a V2 setup (array root). Each element supplies "drivers" (signal peaks) and
    /// "targets" (rig-driver weights). Elements with no drivers or no targets are skipped.
    /// </summary>
    private static List<Rule> LoadV2(string json, Dictionary<string, int> signals, Dictionary<string, int> rigDrivers)
    {
        var mapping = JSONRigParser.DeserializeV2Mapping(json);
        var rules = new List<Rule>();
        foreach (var dt in mapping)
        {
            var drivers = LoadWeights(dt["drivers"], signals);
            if (drivers.Count == 0)
            {
                continue;
            }

            var targets = LoadWeights(dt["targets"], rigDrivers);
            if (targets.Count == 0)
            {
                continue;
            }

            rules.Add(new Rule { Drivers = drivers.ToArray(), Targets = targets.ToArray() });
        }

        return rules;
    }

    /// <summary>
    /// Returns true when the first significant character of <paramref name="json"/> is '{'.
    /// V1 setups have an object root and V2 setups an array root. Leading whitespace and a
    /// byte-order mark are skipped.
    /// </summary>
    private static bool IsV1Json(string json)
    {
        foreach (var ch in json)
        {
            if (char.IsWhiteSpace(ch) || ch == '\uFEFF')
            {
                continue;
            }

            return ch == '{';
        }

        return false;
    }

    /// <summary>
    /// Inverts a name-to-index map into an array of names ordered by index.
    /// </summary>
    private static string[] ToNameArray(Dictionary<string, int> indices)
    {
        var names = new string[indices.Count];
        foreach (var pair in indices)
        {
            names[pair.Value] = pair.Key;
        }

        // GetOrAddIndex assigns indices densely (0..Count-1), so every slot should be filled.
        Debug.Assert(Array.IndexOf(names, null) < 0);
        return names;
    }

    // Debug helper: formats a matrix row by row. Not called from within this class.
    private static string PrintMatrix(Matrix m)
    {
        var sb = new StringBuilder();
        for (var r = 0; r < m.Rows; ++r)
        {
            for (var c = 0; c < m.Cols; ++c)
            {
                sb.Append(m[r, c].ToString("0.###"));
                sb.Append(" ");
            }

            sb.AppendLine();
        }

        return sb.ToString();
    }

    // Debug helper: formats a vector as space-separated values. Not called from within this class.
    private static string PrintVector(IEnumerable<float> m)
    {
        var sb = new StringBuilder();
        foreach (var a in m)
        {
            sb.Append($"{a} ");
        }

        return sb.ToString();
    }

    /// <summary>
    /// Main retargeter constructor.
    /// </summary>
    /// <param name="json">JSON configuration: a V1 setup (object root) or a V2 setup (array root).</param>
    /// <param name="useSparseDeltaMatrix">Whether to use sparse delta matrix or not.</param>
    /// <exception cref="ArgumentNullException"><paramref name="json"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="json"/> is empty or whitespace.</exception>
    public Retargeter(string json, bool useSparseDeltaMatrix = true)
    {
        if (json == null)
        {
            throw new ArgumentNullException(nameof(json));
        }

        if (string.IsNullOrWhiteSpace(json))
        {
            throw new ArgumentException("JSON configuration is empty.", nameof(json));
        }

        var signals = new Dictionary<string, int>();
        var rigDrivers = new Dictionary<string, int>();
        var rulesList = IsV1Json(json) ? LoadV1(json, signals, rigDrivers) : LoadV2(json, signals, rigDrivers);

        // One row per rule: the input values at the rule's peak, and the outputs it should produce there.
        var peaks = new DenseMatrix(rulesList.Count, signals.Count);
        var targets = new DenseMatrix(rulesList.Count, rigDrivers.Count);
        for (var i = 0; i < rulesList.Count; ++i)
        {
            rulesList[i].Peak(peaks.Row(i), targets.Row(i));
        }

        // m[r, c] = activation of rule r when the inputs sit at rule c's peak.
        var m = new DenseMatrix(rulesList.Count, rulesList.Count);
        for (var r = 0; r < m.Rows; ++r)
        {
            var rule = rulesList[r];
            for (var c = 0; c < m.Cols; ++c)
            {
                m[r, c] = rule.Eval(peaks.Row(c));
            }
        }

        _rules = rulesList.ToArray();

        // deltas = (m^-1)^T * targets, so the activations at rule c's peak map back to targets row c.
        // NOTE(review): behavior when m is singular (e.g. duplicate rules) depends on Matrix.Invert — confirm.
        m.Invert();
        m.Transpose();
        var deltas = DenseMatrix.Mult(m, targets);
        if (useSparseDeltaMatrix)
        {
            _deltas = new SparseMatrix(deltas);
        }
        else
        {
            _deltas = deltas;
        }

        InputSignals = ToNameArray(signals);
        OutputSignals = ToNameArray(rigDrivers);
        _activations = new float[_rules.Length];
    }

    /// <summary>
    /// Computes output signals from input signals.
    /// </summary>
    /// <param name="signals">Input signal values, ordered as <see cref="InputSignals"/>.</param>
    /// <param name="outputs">Receives the output values, ordered as <see cref="OutputSignals"/> and clamped to [0, 1].</param>
    /// <exception cref="ArgumentNullException"><paramref name="signals"/> or <paramref name="outputs"/> is null.</exception>
    /// <exception cref="ArgumentException">An array's length does not match the configured signal count.</exception>
    public void Eval(float[] signals, float[] outputs)
    {
        if (signals == null)
        {
            throw new ArgumentNullException(nameof(signals));
        }

        if (outputs == null)
        {
            throw new ArgumentNullException(nameof(outputs));
        }

        if (signals.Length != InputSignals.Length)
        {
            throw new ArgumentException($"Expected {InputSignals.Length} input signals, got {signals.Length}", nameof(signals));
        }

        if (outputs.Length != OutputSignals.Length)
        {
            throw new ArgumentException($"Expected {OutputSignals.Length} output signals, got {outputs.Length}", nameof(outputs));
        }

        for (var i = 0; i < _rules.Length; ++i)
        {
            _activations[i] = _rules[i].Eval(signals);
        }

        Matrix.Mult(_activations, _deltas, outputs);
        for (var i = 0; i < outputs.Length; ++i)
        {
            outputs[i] = Mathf.Clamp(outputs[i], 0.0f, 1.0f);
        }
    }
}
}