VR4RoboticArm2/VR4RoboticArm/Library/PackageCache/com.meta.xr.sdk.movement/Runtime/Tracking/Scripts/A2E/Retargeter.cs
IonutMocanu 48cccc22ad Main2
2025-09-08 11:13:29 +03:00

325 lines
10 KiB
C#

// Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using UnityEngine;
namespace Meta.XR.Movement.FaceTracking.Samples
{
/// <summary>
/// Maintains input and output signals; has <see cref="Eval(float[], float[])"/>
/// which allows producing output signals from supplied inputs.
/// </summary>
/// <remarks>
/// Each rule pairs driver (input) peaks with target (output) weights. The constructor
/// evaluates every rule at every rule's peak to build a square matrix, calls
/// <c>Invert()</c> and <c>Transpose()</c> on it, and multiplies the result by the target
/// matrix to obtain the delta matrix used by <see cref="Eval(float[], float[])"/>.
/// <see cref="Eval(float[], float[])"/> writes into an instance-level scratch buffer,
/// so a single instance must not be evaluated concurrently from multiple threads.
/// </remarks>
public class Retargeter
{
    /// <summary>
    /// List of input signals; the position of a name is the index expected in the
    /// <c>signals</c> array passed to <see cref="Eval(float[], float[])"/>.
    /// </summary>
    public string[] InputSignals { get; private set; }

    /// <summary>
    /// List of output signals; the position of a name is the index written in the
    /// <c>outputs</c> array passed to <see cref="Eval(float[], float[])"/>.
    /// </summary>
    public string[] OutputSignals { get; private set; }

    // Tolerance under which a signal is treated as sitting exactly on a driver peak.
    private const float Eps = 1e-4f;

    /// <summary>
    /// An (index, weight) pair. For drivers the weight is the peak signal value of the
    /// rule; for targets it is the output weight the rule produces.
    /// </summary>
    private struct Item
    {
        /// <summary>
        /// Index field.
        /// </summary>
        public int Index;

        /// <summary>
        /// Weight field.
        /// </summary>
        public float Weight;

        /// <summary>
        /// Evaluates a triangular response that is 1 when <paramref name="signal"/> equals
        /// <see cref="Weight"/> and falls linearly to 0 at signal 0 and at signal 1.
        /// </summary>
        /// <remarks>
        /// NOTE(review): if Weight is exactly 0 (or 1) and the signal lies below 0 (or
        /// above 1) by more than Eps, the division by zero yields an infinity. Callers
        /// presumably supply signals in [0, 1] — confirm.
        /// </remarks>
        /// <param name="signal">Signal value.</param>
        /// <returns>signal / Weight below the peak, (1 - signal) / (1 - Weight) above it.</returns>
        public float Eval(float signal)
        {
            if (Mathf.Abs(signal - Weight) < Eps)
            {
                return 1.0f;
            }
            return signal <= Weight ? signal / Weight : (1f - signal) / (1f - Weight);
        }
    }

    /// <summary>
    /// A retargeting rule: a set of driver peaks and the target weights produced there.
    /// </summary>
    private struct Rule
    {
        /// <summary>
        /// List of drivers.
        /// </summary>
        public Item[] Drivers;

        /// <summary>
        /// List of targets.
        /// </summary>
        public Item[] Targets;

        /// <summary>
        /// Writes this rule's driver peak values into <paramref name="signals"/> and its
        /// target weights into <paramref name="targets"/>; other entries are untouched.
        /// </summary>
        /// <param name="signals">Row receiving the driver peaks, indexed by input signal.</param>
        /// <param name="targets">Row receiving the target weights, indexed by output signal.</param>
        public void Peak(float[] signals, float[] targets)
        {
            foreach (var dr in Drivers)
            {
                signals[dr.Index] = dr.Weight;
            }
            foreach (var t in Targets)
            {
                targets[t.Index] = t.Weight;
            }
        }

        /// <summary>
        /// Returns the product, over all drivers, of the driver's triangular response to
        /// its signal multiplied by the driver's weight (array overload).
        /// </summary>
        /// <param name="signals">Input signals, indexed by input signal index.</param>
        /// <returns>Computed weight.</returns>
        public float Eval(float[] signals)
        {
            var weight = 1.0f;
            foreach (var d in Drivers)
            {
                weight *= d.Eval(signals[d.Index]) * d.Weight;
            }
            return weight;
        }

        /// <summary>
        /// Same as the array overload, for any <see cref="IList{T}"/> of signals.
        /// </summary>
        /// <param name="signals">Input signals, indexed by input signal index.</param>
        /// <returns>Computed weight.</returns>
        public float Eval(IList<float> signals)
        {
            var weight = 1.0f;
            foreach (var d in Drivers)
            {
                weight *= d.Eval(signals[d.Index]) * d.Weight;
            }
            return weight;
        }
    }

    private readonly Rule[] _rules;
    private readonly Matrix _deltas;
    // Per-rule scratch buffer reused by every Eval call.
    private readonly float[] _activations;

    private struct LoadedWeights
    {
        public List<Item> Items;
    }

    // Returns the index already assigned to key, or assigns it the next free index (the current count).
    private static int GetOrAddIndex(Dictionary<string, int> indices, string key)
    {
        if (!indices.TryGetValue(key, out var index))
        {
            index = indices.Count;
            indices[key] = index;
        }
        return index;
    }

    // Converts a name -> weight map into Items, registering unseen names in indices.
    private static LoadedWeights LoadWeights(Dictionary<string, float> d, Dictionary<string, int> indices)
    {
        var items = new List<Item>(d.Count);
        foreach (var pair in d)
        {
            items.Add(new Item { Index = GetOrAddIndex(indices, pair.Key), Weight = pair.Value });
        }
        return new LoadedWeights { Items = items };
    }

    // V1: each input signal maps directly to target weights; the signal's peak is 1.
    private static List<Rule> LoadV1(string json, Dictionary<string, int> signals, Dictionary<string, int> rigDrivers)
    {
        var mapping = JSONRigParser.DeserializeV1Mapping(json);
        var rules = new List<Rule>();
        foreach (var s in mapping)
        {
            // The signal is registered even when the rule is skipped below, as before.
            var signalIndex = GetOrAddIndex(signals, s.Key);
            var targets = LoadWeights(s.Value, rigDrivers);
            if (targets.Items.Count == 0)
            {
                continue;
            }
            rules.Add(new Rule
            {
                Drivers = new[] { new Item { Index = signalIndex, Weight = 1f } },
                Targets = targets.Items.ToArray(),
            });
        }
        return rules;
    }

    // V2: each entry lists explicit "drivers" and "targets" weight maps.
    private static List<Rule> LoadV2(string json, Dictionary<string, int> signals, Dictionary<string, int> rigDrivers)
    {
        var mapping = JSONRigParser.DeserializeV2Mapping(json);
        var rules = new List<Rule>();
        foreach (var dt in mapping)
        {
            var drivers = LoadWeights(dt["drivers"], signals);
            if (drivers.Items.Count == 0)
            {
                continue;
            }
            var targets = LoadWeights(dt["targets"], rigDrivers);
            if (targets.Items.Count == 0)
            {
                continue;
            }
            rules.Add(new Rule { Drivers = drivers.Items.ToArray(), Targets = targets.Items.ToArray() });
        }
        return rules;
    }

    // Diagnostic helper; not called from within this class.
    private static string PrintMatrix(Matrix m)
    {
        var sb = new StringBuilder();
        for (var r = 0; r < m.Rows; ++r)
        {
            for (var c = 0; c < m.Cols; ++c)
            {
                sb.Append(m[r, c].ToString("0.###"));
                sb.Append(" ");
            }
            sb.AppendLine();
        }
        return sb.ToString();
    }

    // Diagnostic helper; not called from within this class.
    private static string PrintVector(List<float> m)
    {
        var sb = new StringBuilder();
        foreach (var a in m)
        {
            sb.Append($"{a} ");
        }
        return sb.ToString();
    }

    // Returns the first character that is neither whitespace nor a byte-order mark, or '\0' if none.
    private static char FirstSignificantChar(string json)
    {
        foreach (var c in json)
        {
            if (!char.IsWhiteSpace(c) && c != '\uFEFF')
            {
                return c;
            }
        }
        return '\0';
    }

    // Inverts a name -> index map into an array whose slot i holds the name with index i.
    private static string[] NamesByIndex(Dictionary<string, int> indices)
    {
        var names = new string[indices.Count];
        foreach (var pair in indices)
        {
            names[pair.Value] = pair.Key;
        }
        // Indices are handed out contiguously by GetOrAddIndex, so every slot is filled.
        Debug.Assert(Array.IndexOf(names, null) < 0);
        return names;
    }

    /// <summary>
    /// Main retargeter constructor.
    /// </summary>
    /// <param name="json">JSON configuration (a V1 object root or a V2 array root).</param>
    /// <param name="useSparseDeltaMatrix">Whether to use sparse delta matrix or not.</param>
    /// <exception cref="ArgumentNullException"><paramref name="json"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="json"/> is empty or whitespace only.</exception>
    public Retargeter(string json, bool useSparseDeltaMatrix = true)
    {
        if (json == null)
        {
            throw new ArgumentNullException(nameof(json));
        }
        var rootChar = FirstSignificantChar(json);
        if (rootChar == '\0')
        {
            throw new ArgumentException("JSON configuration is empty.", nameof(json));
        }

        var signals = new Dictionary<string, int>();
        var rigDrivers = new Dictionary<string, int>();
        // V1 setups' root element is a dict, V2 setups' root element is an array. Leading
        // whitespace or a BOM is skipped so it cannot route a V1 document to the V2 loader.
        var rulesList = rootChar == '{'
            ? LoadV1(json, signals, rigDrivers)
            : LoadV2(json, signals, rigDrivers);

        var peaks = new DenseMatrix(rulesList.Count, signals.Count);
        var targets = new DenseMatrix(rulesList.Count, rigDrivers.Count);
        for (var i = 0; i < rulesList.Count; ++i)
        {
            rulesList[i].Peak(peaks.Row(i), targets.Row(i));
        }

        // m[r, c] is rule r's activation when the inputs sit at rule c's peak.
        var m = new DenseMatrix(rulesList.Count, rulesList.Count);
        for (var r = 0; r < m.Rows; ++r)
        {
            var rule = rulesList[r];
            for (var c = 0; c < m.Cols; ++c)
            {
                m[r, c] = rule.Eval(peaks.Row(c));
            }
        }
        _rules = rulesList.ToArray();
        m.Invert();
        m.Transpose();
        var deltas = DenseMatrix.Mult(m, targets);
        _deltas = useSparseDeltaMatrix ? new SparseMatrix(deltas) : deltas;

        InputSignals = NamesByIndex(signals);
        OutputSignals = NamesByIndex(rigDrivers);
        _activations = new float[_rules.Length];
    }

    /// <summary>
    /// Computes output signals from input signals. Each output is clamped to [0, 1].
    /// Not safe to call concurrently on the same instance (shared scratch buffer).
    /// </summary>
    /// <param name="signals">Input signals, ordered as <see cref="InputSignals"/>.</param>
    /// <param name="outputs">Destination for output signals, ordered as <see cref="OutputSignals"/>.</param>
    /// <exception cref="ArgumentNullException">Either array is null.</exception>
    /// <exception cref="ArgumentException">An array length does not match the configured signal count.</exception>
    public void Eval(float[] signals, float[] outputs)
    {
        if (signals == null)
        {
            throw new ArgumentNullException(nameof(signals));
        }
        if (outputs == null)
        {
            throw new ArgumentNullException(nameof(outputs));
        }
        if (signals.Length != InputSignals.Length)
        {
            throw new ArgumentException($"Expected {InputSignals.Length} input signals, got {signals.Length}");
        }
        if (outputs.Length != OutputSignals.Length)
        {
            throw new ArgumentException($"Expected {OutputSignals.Length} output signals, got {outputs.Length}");
        }
        for (var i = 0; i < _rules.Length; ++i)
        {
            _activations[i] = _rules[i].Eval(signals);
        }
        Matrix.Mult(_activations, _deltas, outputs);
        for (var i = 0; i < outputs.Length; ++i)
        {
            outputs[i] = Mathf.Clamp(outputs[i], 0.0f, 1.0f);
        }
    }
}
}