// Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using UnityEngine;

namespace Meta.XR.Movement.FaceTracking.Samples
{
    /// <summary>
    /// Maintains input and output signals; has <see cref="Eval(float[], float[])"/>
    /// which allows producing output signals from supplied inputs.
    /// </summary>
    public class Retargeter
    {
        /// <summary>
        /// List of input signals. The position of a name is the index at which
        /// its value is expected in the input array passed to
        /// <see cref="Eval(float[], float[])"/>.
        /// </summary>
        public string[] InputSignals { get; private set; }

        /// <summary>
        /// List of output signals. The position of a name is the index at which
        /// its value is written in the output array passed to
        /// <see cref="Eval(float[], float[])"/>.
        /// </summary>
        public string[] OutputSignals { get; private set; }

        /// <summary>
        /// Tolerance used when comparing a signal against an item's weight.
        /// </summary>
        private const float Eps = 1e-4f;

        /// <summary>
        /// An (index, weight) pair referring either to an input signal or to an output.
        /// </summary>
        private struct Item
        {
            /// <summary>
            /// Index field.
            /// </summary>
            public int Index;

            /// <summary>
            /// Weight field.
            /// </summary>
            public float Weight;

            /// <summary>
            /// Evaluate provided signal using weight.
            /// </summary>
            /// <param name="signal">Signal value.</param>
            /// <returns>
            /// 1 when the signal is within <see cref="Eps"/> of the weight; otherwise
            /// signal as a fraction of weight or (1 - signal)/(1 - weight).
            /// </returns>
            public float Eval(float signal)
            {
                // The tolerance check runs first, so a signal equal to the weight
                // never reaches the divisions below.
                if (Mathf.Abs(signal - Weight) < Eps)
                {
                    return 1.0f;
                }

                // NOTE(review): no guard for Weight == 0 with a negative signal or
                // Weight == 1 with a signal above 1; both divide by zero.
                return signal <= Weight
                    ? signal / Weight
                    : (1f - signal) / (1f - Weight);
            }
        }

        /// <summary>
        /// A set of driver items (indices into the input signals) paired with a
        /// set of target items (indices into the outputs).
        /// </summary>
        private struct Rule
        {
            /// <summary>
            /// List of drivers.
            /// </summary>
            public Item[] Drivers;

            /// <summary>
            /// List of targets.
            /// </summary>
            public Item[] Targets;

            /// <summary>
            /// Obtains signal and target weights: writes each driver's weight into
            /// <paramref name="signals"/> and each target's weight into
            /// <paramref name="targets"/>, at the respective item index.
            /// </summary>
            /// <param name="signals">Signals list.</param>
            /// <param name="targets">Targets list.</param>
            public void Peak(float[] signals, float[] targets)
            {
                foreach (var driver in Drivers)
                {
                    signals[driver.Index] = driver.Weight;
                }

                foreach (var target in Targets)
                {
                    targets[target.Index] = target.Weight;
                }
            }

            /// <summary>
            /// Returns computed weight from (array) signals: the product, over all
            /// drivers, of the driver's evaluated signal times the driver's weight.
            /// </summary>
            /// <param name="signals">Signals list.</param>
            /// <returns>Computed weight.</returns>
            public float Eval(float[] signals)
            {
                var weight = 1.0f;
                foreach (var driver in Drivers)
                {
                    weight *= driver.Eval(signals[driver.Index]) * driver.Weight;
                }

                return weight;
            }

            /// <summary>
            /// Returns computed weight from (list) signals; same computation as the
            /// array overload.
            /// </summary>
            /// <param name="signals">Signals list.</param>
            /// <returns>Computed weight.</returns>
            public float Eval(IList<float> signals)
            {
                var weight = 1.0f;
                foreach (var driver in Drivers)
                {
                    weight *= driver.Eval(signals[driver.Index]) * driver.Weight;
                }

                return weight;
            }
        }

        // Rules loaded from the JSON configuration.
        private readonly Rule[] _rules;

        // Matrix handed to Matrix.Mult together with the activations to produce outputs.
        private readonly Matrix _deltas;

        // Scratch buffer holding one activation per rule; overwritten on every Eval call.
        private readonly float[] _activations;

        private struct LoadedWeights
        {
            public List<Item> Items;
        }

        /// <summary>
        /// Returns the index registered for <paramref name="key"/>, registering the
        /// next free index (the current entry count) when the key is not yet known.
        /// </summary>
        private static int GetOrAddIndex(Dictionary<string, int> indices, string key)
        {
            if (!indices.TryGetValue(key, out var index))
            {
                index = indices.Count;
                indices[key] = index;
            }

            return index;
        }

        /// <summary>
        /// Converts a name-to-weight dictionary into items, registering any
        /// previously unseen names in <paramref name="indices"/>.
        /// </summary>
        private static LoadedWeights LoadWeights(Dictionary<string, float> d, Dictionary<string, int> indices)
        {
            var items = new List<Item>();
            foreach (var pair in d)
            {
                items.Add(new Item
                {
                    Index = GetOrAddIndex(indices, pair.Key),
                    Weight = pair.Value
                });
            }

            return new LoadedWeights
            {
                Items = items,
            };
        }

        /// <summary>
        /// Loads a V1 mapping: each entry maps one signal name to its target weights
        /// and becomes a rule with a single driver of weight 1.
        /// </summary>
        private static List<Rule> LoadV1(
            string json,
            Dictionary<string, int> signals,
            Dictionary<string, int> rigDrivers)
        {
            var mapping = JSONRigParser.DeserializeV1Mapping(json);
            var rules = new List<Rule>();
            foreach (var entry in mapping)
            {
                // The signal is registered before the empty-target check, so it
                // keeps its index even when no rule is produced for it.
                var signalIndex = GetOrAddIndex(signals, entry.Key);

                var targets = LoadWeights(entry.Value, rigDrivers);
                if (targets.Items.Count == 0)
                {
                    continue;
                }

                rules.Add(new Rule
                {
                    Drivers = new[] { new Item { Index = signalIndex, Weight = 1f } },
                    Targets = targets.Items.ToArray()
                });
            }

            return rules;
        }

        /// <summary>
        /// Loads a V2 mapping: each entry carries a "drivers" and a "targets"
        /// dictionary; entries with either one empty are skipped.
        /// </summary>
        private static List<Rule> LoadV2(
            string json,
            Dictionary<string, int> signals,
            Dictionary<string, int> rigDrivers)
        {
            var mapping = JSONRigParser.DeserializeV2Mapping(json);
            var rules = new List<Rule>();
            foreach (var entry in mapping)
            {
                // Drivers are loaded (and registered) before targets are inspected.
                var drivers = LoadWeights(entry["drivers"], signals);
                if (drivers.Items.Count == 0)
                {
                    continue;
                }

                var targets = LoadWeights(entry["targets"], rigDrivers);
                if (targets.Items.Count == 0)
                {
                    continue;
                }

                rules.Add(new Rule
                {
                    Drivers = drivers.Items.ToArray(),
                    Targets = targets.Items.ToArray()
                });
            }

            return rules;
        }

        /// <summary>
        /// Formats a matrix as text, one row per line, each value using "0.###".
        /// </summary>
        private static string PrintMatrix(Matrix m)
        {
            var sb = new StringBuilder();
            for (var r = 0; r < m.Rows; ++r)
            {
                for (var c = 0; c < m.Cols; ++c)
                {
                    sb.Append(m[r, c].ToString("0.###"));
                    sb.Append(" ");
                }

                sb.AppendLine();
            }

            return sb.ToString();
        }

        /// <summary>
        /// Formats a list as its elements separated by single spaces.
        /// </summary>
        private static string PrintVector<T>(List<T> m)
        {
            var sb = new StringBuilder();
            foreach (var a in m)
            {
                sb.Append($"{a} ");
            }

            return sb.ToString();
        }

        /// <summary>
        /// Decides which loader to use by looking at the first non-whitespace
        /// character of the configuration.
        /// </summary>
        /// <returns>True when the root element starts with '{'.</returns>
        private static bool IsV1Json(string json)
        {
            if (json == null)
            {
                throw new ArgumentNullException(nameof(json));
            }

            foreach (var ch in json)
            {
                if (!char.IsWhiteSpace(ch))
                {
                    return ch == '{';
                }
            }

            throw new ArgumentException("JSON configuration is empty.", nameof(json));
        }

        /// <summary>
        /// Builds the array in which each name sits at the index the dictionary
        /// assigns to it.
        /// </summary>
        private static string[] ToIndexedNames(Dictionary<string, int> indices)
        {
            var names = Enumerable.Repeat("", indices.Count).ToArray();
            foreach (var pair in indices)
            {
                names[pair.Value] = pair.Key;
            }

            // Every slot must have been filled with a non-empty name.
            Debug.Assert(!names.Contains(""));
            return names;
        }

        /// <summary>
        /// Main retargeter constructor.
        /// </summary>
        /// <param name="json">JSON configuration.</param>
        /// <param name="useSparseDeltaMatrix">Whether to use sparse delta matrix or not.</param>
        /// <exception cref="ArgumentNullException"><paramref name="json"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="json"/> is empty or whitespace.</exception>
        public Retargeter(string json, bool useSparseDeltaMatrix = true)
        {
            var signals = new Dictionary<string, int>();
            var rigDrivers = new Dictionary<string, int>();

            // V1 setups root element is a dict, V2 setups root element is an array
            var rulesList = IsV1Json(json)
                ? LoadV1(json, signals, rigDrivers)
                : LoadV2(json, signals, rigDrivers);

            // One row per rule: the rule's driver weights and its target weights.
            // Assumes Row(i) exposes the matrix storage so Peak's writes land in the matrix.
            var peaks = new DenseMatrix(rulesList.Count, signals.Count);
            var targets = new DenseMatrix(rulesList.Count, rigDrivers.Count);
            for (var i = 0; i < rulesList.Count; ++i)
            {
                rulesList[i].Peak(peaks.Row(i), targets.Row(i));
            }

            // m[r, c] is rule r evaluated on the peak signals of rule c.
            var m = new DenseMatrix(rulesList.Count, rulesList.Count);
            for (var r = 0; r < m.Rows; ++r)
            {
                var rule = rulesList[r];
                for (var c = 0; c < m.Cols; ++c)
                {
                    m[r, c] = rule.Eval(peaks.Row(c));
                }
            }

            _rules = rulesList.ToArray();

            // Results are consumed through m itself, so both calls are expected
            // to operate in place.
            m.Invert();
            m.Transpose();
            _deltas = useSparseDeltaMatrix ? new SparseMatrix(DenseMatrix.Mult(m, targets)) : DenseMatrix.Mult(m, targets);

            InputSignals = ToIndexedNames(signals);
            OutputSignals = ToIndexedNames(rigDrivers);

            _activations = new float[_rules.Length];
        }

        /// <summary>
        /// Runs eval on input and output signals: evaluates every rule against
        /// <paramref name="signals"/>, combines the activations with the delta
        /// matrix into <paramref name="outputs"/>, then clamps each output to [0, 1].
        /// </summary>
        /// <param name="signals">Input signals.</param>
        /// <param name="outputs">Output signals.</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="signals"/> or <paramref name="outputs"/> is null.
        /// </exception>
        /// <exception cref="ArgumentException">
        /// Exception thrown if an array length does not match the number of
        /// input or output signals.
        /// </exception>
        public void Eval(float[] signals, float[] outputs)
        {
            if (signals == null)
            {
                throw new ArgumentNullException(nameof(signals));
            }

            if (outputs == null)
            {
                throw new ArgumentNullException(nameof(outputs));
            }

            if (signals.Length != InputSignals.Length)
            {
                throw new ArgumentException($"Expected {InputSignals.Length} input signals, got {signals.Length}");
            }

            if (outputs.Length != OutputSignals.Length)
            {
                throw new ArgumentException($"Expected {OutputSignals.Length} output signals, got {outputs.Length}");
            }

            for (var i = 0; i < _rules.Length; ++i)
            {
                _activations[i] = _rules[i].Eval(signals);
            }

            Matrix.Mult(_activations, _deltas, outputs);

            for (var i = 0; i < outputs.Length; ++i)
            {
                outputs[i] = Mathf.Clamp(outputs[i], 0.0f, 1.0f);
            }
        }
    }
}