mirror of
https://github.com/deepmodeling/Uni-Lab-OS
synced 2026-04-23 22:39:59 +00:00
Compare commits
4 Commits
5dca3d8c3d
...
feat/3d_bu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9862415655 | ||
|
|
18296d3cb2 | ||
|
|
090d5c5cb5 | ||
|
|
48e13a7b4d |
172
tests/app/__init__.py
Normal file
172
tests/app/__init__.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""normalize_model_for_upload 单元测试"""
|
||||
|
||||
import unittest
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加项目根目录到 sys.path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
|
||||
from unilabos.app.register import normalize_model_for_upload
|
||||
|
||||
|
||||
class TestNormalizeModelForUpload(unittest.TestCase):
|
||||
"""测试 Registry YAML model 字段标准化"""
|
||||
|
||||
def test_empty_input(self):
|
||||
"""空 dict 直接返回"""
|
||||
self.assertEqual(normalize_model_for_upload({}), {})
|
||||
self.assertIsNone(normalize_model_for_upload(None))
|
||||
|
||||
def test_format_infer_xacro(self):
|
||||
"""自动从 path 后缀推断 format=xacro"""
|
||||
model = {
|
||||
"path": "https://oss.example.com/devices/arm/macro_device.xacro",
|
||||
"mesh": "arm_slider",
|
||||
"type": "device",
|
||||
}
|
||||
result = normalize_model_for_upload(model)
|
||||
self.assertEqual(result["format"], "xacro")
|
||||
|
||||
def test_format_infer_urdf(self):
|
||||
"""自动推断 format=urdf"""
|
||||
model = {"path": "https://example.com/robot.urdf", "type": "device"}
|
||||
result = normalize_model_for_upload(model)
|
||||
self.assertEqual(result["format"], "urdf")
|
||||
|
||||
def test_format_infer_stl(self):
|
||||
"""自动推断 format=stl"""
|
||||
model = {"path": "https://example.com/part.stl"}
|
||||
result = normalize_model_for_upload(model)
|
||||
self.assertEqual(result["format"], "stl")
|
||||
|
||||
def test_format_infer_gltf(self):
|
||||
"""自动推断 format=gltf(.gltf 和 .glb)"""
|
||||
for ext in [".gltf", ".glb"]:
|
||||
model = {"path": f"https://example.com/model{ext}"}
|
||||
result = normalize_model_for_upload(model)
|
||||
self.assertEqual(result["format"], "gltf", f"failed for {ext}")
|
||||
|
||||
def test_format_not_overwritten(self):
|
||||
"""已有 format 字段时不覆盖"""
|
||||
model = {
|
||||
"path": "https://example.com/model.xacro",
|
||||
"format": "custom",
|
||||
}
|
||||
result = normalize_model_for_upload(model)
|
||||
self.assertEqual(result["format"], "custom")
|
||||
|
||||
def test_format_no_path(self):
|
||||
"""没有 path 时不推断 format"""
|
||||
model = {"mesh": "arm_slider", "type": "device"}
|
||||
result = normalize_model_for_upload(model)
|
||||
self.assertNotIn("format", result)
|
||||
|
||||
def test_children_mesh_string_to_struct(self):
|
||||
"""将 children_mesh 字符串(旧格式)转为结构化对象"""
|
||||
model = {
|
||||
"path": "https://example.com/rack.xacro",
|
||||
"type": "resource",
|
||||
"children_mesh": "tip/meshes/tip.stl",
|
||||
"children_mesh_tf": [0.0045, 0.0045, 0, 0, 0, 1.57],
|
||||
"children_mesh_path": "https://oss.example.com/tip.stl",
|
||||
}
|
||||
result = normalize_model_for_upload(model)
|
||||
|
||||
# children_mesh 应变为 dict
|
||||
cm = result["children_mesh"]
|
||||
self.assertIsInstance(cm, dict)
|
||||
self.assertEqual(cm["path"], "https://oss.example.com/tip.stl") # 优先使用 OSS URL
|
||||
self.assertEqual(cm["format"], "stl")
|
||||
self.assertTrue(cm["default_visible"])
|
||||
self.assertEqual(cm["local_offset"], [0.0045, 0.0045, 0])
|
||||
self.assertEqual(cm["local_rotation"], [0, 0, 1.57])
|
||||
|
||||
# 旧字段应被移除
|
||||
self.assertNotIn("children_mesh_tf", result)
|
||||
self.assertNotIn("children_mesh_path", result)
|
||||
|
||||
def test_children_mesh_no_oss_fallback(self):
|
||||
"""children_mesh 无 OSS URL 时 fallback 到本地路径"""
|
||||
model = {
|
||||
"path": "https://example.com/rack.xacro",
|
||||
"children_mesh": "plate_96/meshes/plate_96.stl",
|
||||
}
|
||||
result = normalize_model_for_upload(model)
|
||||
cm = result["children_mesh"]
|
||||
self.assertEqual(cm["path"], "plate_96/meshes/plate_96.stl")
|
||||
self.assertEqual(cm["format"], "stl")
|
||||
|
||||
def test_children_mesh_gltf_format(self):
|
||||
"""children_mesh .glb 文件推断 format=gltf"""
|
||||
model = {
|
||||
"path": "https://example.com/rack.xacro",
|
||||
"children_mesh": "meshes/child.glb",
|
||||
}
|
||||
result = normalize_model_for_upload(model)
|
||||
self.assertEqual(result["children_mesh"]["format"], "gltf")
|
||||
|
||||
def test_children_mesh_partial_tf(self):
|
||||
"""children_mesh_tf 只有 3 个值时只有 offset 无 rotation"""
|
||||
model = {
|
||||
"path": "https://example.com/rack.xacro",
|
||||
"children_mesh": "tip.stl",
|
||||
"children_mesh_tf": [0.01, 0.02, 0.03],
|
||||
}
|
||||
result = normalize_model_for_upload(model)
|
||||
cm = result["children_mesh"]
|
||||
self.assertEqual(cm["local_offset"], [0.01, 0.02, 0.03])
|
||||
self.assertNotIn("local_rotation", cm)
|
||||
|
||||
def test_children_mesh_no_tf(self):
|
||||
"""children_mesh 无 tf 时不加 offset/rotation"""
|
||||
model = {
|
||||
"path": "https://example.com/rack.xacro",
|
||||
"children_mesh": "tip.stl",
|
||||
}
|
||||
result = normalize_model_for_upload(model)
|
||||
cm = result["children_mesh"]
|
||||
self.assertNotIn("local_offset", cm)
|
||||
self.assertNotIn("local_rotation", cm)
|
||||
|
||||
def test_children_mesh_already_dict(self):
|
||||
"""children_mesh 已经是 dict 时不重新映射"""
|
||||
model = {
|
||||
"path": "https://example.com/rack.xacro",
|
||||
"children_mesh": {
|
||||
"path": "https://example.com/tip.stl",
|
||||
"format": "stl",
|
||||
"default_visible": False,
|
||||
},
|
||||
}
|
||||
result = normalize_model_for_upload(model)
|
||||
cm = result["children_mesh"]
|
||||
self.assertIsInstance(cm, dict)
|
||||
self.assertFalse(cm["default_visible"])
|
||||
|
||||
def test_original_not_mutated(self):
|
||||
"""原始 dict 不被修改"""
|
||||
original = {
|
||||
"path": "https://example.com/model.xacro",
|
||||
"mesh": "arm",
|
||||
}
|
||||
original_copy = {**original}
|
||||
normalize_model_for_upload(original)
|
||||
self.assertEqual(original, original_copy)
|
||||
|
||||
def test_preserves_existing_fields(self):
|
||||
"""所有原始字段都被保留"""
|
||||
model = {
|
||||
"path": "https://example.com/model.xacro",
|
||||
"mesh": "arm_slider",
|
||||
"type": "device",
|
||||
"mesh_tf": [0, 0, 0, 0, 0, 0],
|
||||
"custom_field": "should_survive",
|
||||
}
|
||||
result = normalize_model_for_upload(model)
|
||||
self.assertEqual(result["custom_field"], "should_survive")
|
||||
self.assertEqual(result["mesh_tf"], [0, 0, 0, 0, 0, 0])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
496
tests/app/test_model_upload.py
Normal file
496
tests/app/test_model_upload.py
Normal file
@@ -0,0 +1,496 @@
|
||||
"""model_upload.py 单元测试(upload_device_model / download_model_from_oss / XOR 加解密)"""
|
||||
|
||||
import unittest
|
||||
import tempfile
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
|
||||
from unilabos.app.model_upload import (
|
||||
upload_device_model,
|
||||
download_model_from_oss,
|
||||
_MODEL_EXTENSIONS,
|
||||
_MESH_ENCRYPT_EXTENSIONS,
|
||||
_xor_transform,
|
||||
)
|
||||
|
||||
|
||||
class TestUploadDeviceModel(unittest.TestCase):
|
||||
"""测试本地模型文件上传到 OSS"""
|
||||
|
||||
def setUp(self):
|
||||
self.tmp_dir = tempfile.mkdtemp()
|
||||
self.mock_client = MagicMock()
|
||||
|
||||
def _create_model_files(self, subdir: str, filenames: list[str]):
|
||||
"""在临时目录中创建设备模型文件"""
|
||||
model_dir = Path(self.tmp_dir) / "devices" / subdir
|
||||
model_dir.mkdir(parents=True, exist_ok=True)
|
||||
for name in filenames:
|
||||
p = model_dir / name
|
||||
p.parent.mkdir(parents=True, exist_ok=True)
|
||||
p.write_text("dummy content")
|
||||
return model_dir
|
||||
|
||||
@patch("unilabos.app.model_upload._MESH_BASE_DIR")
|
||||
def test_upload_success(self, mock_base):
|
||||
"""正常上传流程"""
|
||||
mock_base.__truediv__ = lambda self, x: Path(self.tmp_dir) / x
|
||||
# 直接 patch _MESH_BASE_DIR 为 Path(tmp_dir)
|
||||
with patch("unilabos.app.model_upload._MESH_BASE_DIR", Path(self.tmp_dir)):
|
||||
self._create_model_files("arm_slider", ["macro_device.xacro", "meshes/link1.stl"])
|
||||
|
||||
self.mock_client.get_model_upload_urls.return_value = {
|
||||
"files": [
|
||||
{"name": "macro_device.xacro", "upload_url": "https://oss.example.com/put1"},
|
||||
{"name": "meshes/link1.stl", "upload_url": "https://oss.example.com/put2"},
|
||||
]
|
||||
}
|
||||
self.mock_client.publish_model.return_value = {
|
||||
"path": "https://oss.example.com/arm_slider/macro_device.xacro"
|
||||
}
|
||||
|
||||
with patch("unilabos.app.model_upload._put_upload") as mock_put:
|
||||
result = upload_device_model(
|
||||
http_client=self.mock_client,
|
||||
template_uuid="test-uuid",
|
||||
mesh_name="arm_slider",
|
||||
model_type="device",
|
||||
version="1.0.0",
|
||||
)
|
||||
|
||||
self.assertEqual(result, "https://oss.example.com/arm_slider/macro_device.xacro")
|
||||
self.mock_client.get_model_upload_urls.assert_called_once()
|
||||
self.mock_client.publish_model.assert_called_once()
|
||||
|
||||
@patch("unilabos.app.model_upload._MESH_BASE_DIR")
|
||||
def test_upload_dir_not_exists(self, mock_base):
|
||||
"""本地目录不存在时返回 None"""
|
||||
with patch("unilabos.app.model_upload._MESH_BASE_DIR", Path(self.tmp_dir)):
|
||||
result = upload_device_model(
|
||||
http_client=self.mock_client,
|
||||
template_uuid="test-uuid",
|
||||
mesh_name="nonexistent",
|
||||
model_type="device",
|
||||
)
|
||||
self.assertIsNone(result)
|
||||
|
||||
@patch("unilabos.app.model_upload._MESH_BASE_DIR")
|
||||
def test_upload_no_valid_files(self, mock_base):
|
||||
"""目录中无有效模型文件时返回 None"""
|
||||
with patch("unilabos.app.model_upload._MESH_BASE_DIR", Path(self.tmp_dir)):
|
||||
model_dir = Path(self.tmp_dir) / "devices" / "empty_model"
|
||||
model_dir.mkdir(parents=True, exist_ok=True)
|
||||
(model_dir / "readme.txt").write_text("not a model")
|
||||
|
||||
result = upload_device_model(
|
||||
http_client=self.mock_client,
|
||||
template_uuid="test-uuid",
|
||||
mesh_name="empty_model",
|
||||
model_type="device",
|
||||
)
|
||||
self.assertIsNone(result)
|
||||
|
||||
@patch("unilabos.app.model_upload._MESH_BASE_DIR")
|
||||
def test_upload_urls_failure(self, mock_base):
|
||||
"""获取上传 URL 失败时返回 None"""
|
||||
with patch("unilabos.app.model_upload._MESH_BASE_DIR", Path(self.tmp_dir)):
|
||||
self._create_model_files("arm", ["device.xacro"])
|
||||
self.mock_client.get_model_upload_urls.return_value = None
|
||||
|
||||
result = upload_device_model(
|
||||
http_client=self.mock_client,
|
||||
template_uuid="test-uuid",
|
||||
mesh_name="arm",
|
||||
model_type="device",
|
||||
)
|
||||
self.assertIsNone(result)
|
||||
|
||||
|
||||
class TestDownloadModelFromOss(unittest.TestCase):
|
||||
"""测试从 OSS 下载模型文件到本地"""
|
||||
|
||||
def setUp(self):
|
||||
self.tmp_dir = tempfile.mkdtemp()
|
||||
|
||||
def test_skip_no_mesh_name(self):
|
||||
"""缺少 mesh 名称时跳过"""
|
||||
result = download_model_from_oss({"type": "device", "path": "https://x.com/a.xacro"})
|
||||
self.assertFalse(result)
|
||||
|
||||
def test_skip_no_oss_path(self):
|
||||
"""缺少 OSS path 时跳过"""
|
||||
result = download_model_from_oss({"mesh": "arm", "type": "device"})
|
||||
self.assertFalse(result)
|
||||
|
||||
def test_skip_local_path(self):
|
||||
"""非 https:// 路径时跳过"""
|
||||
result = download_model_from_oss({
|
||||
"mesh": "arm",
|
||||
"type": "device",
|
||||
"path": "file:///local/model.xacro",
|
||||
})
|
||||
self.assertFalse(result)
|
||||
|
||||
def test_already_exists(self):
|
||||
"""本地已有文件时跳过下载"""
|
||||
device_dir = Path(self.tmp_dir) / "devices" / "arm"
|
||||
device_dir.mkdir(parents=True, exist_ok=True)
|
||||
(device_dir / "model.xacro").write_text("existing")
|
||||
|
||||
result = download_model_from_oss(
|
||||
{"mesh": "arm", "type": "device", "path": "https://oss.example.com/model.xacro"},
|
||||
mesh_base_dir=Path(self.tmp_dir),
|
||||
)
|
||||
self.assertTrue(result)
|
||||
|
||||
@patch("unilabos.app.model_upload._download_file")
|
||||
def test_download_device(self, mock_download):
|
||||
"""下载 device 模型到 devices/ 目录"""
|
||||
result = download_model_from_oss(
|
||||
{"mesh": "new_arm", "type": "device", "path": "https://oss.example.com/new_arm/macro_device.xacro"},
|
||||
mesh_base_dir=Path(self.tmp_dir),
|
||||
)
|
||||
self.assertTrue(result)
|
||||
mock_download.assert_called_once()
|
||||
call_args = mock_download.call_args
|
||||
self.assertIn("macro_device.xacro", str(call_args[0][1]))
|
||||
|
||||
@patch("unilabos.app.model_upload._download_file")
|
||||
def test_download_resource(self, mock_download):
|
||||
"""下载 resource 模型到 resources/ 目录"""
|
||||
result = download_model_from_oss(
|
||||
{
|
||||
"mesh": "plate_96/meshes/plate_96.stl",
|
||||
"type": "resource",
|
||||
"path": "https://oss.example.com/plate_96/modal.xacro",
|
||||
},
|
||||
mesh_base_dir=Path(self.tmp_dir),
|
||||
)
|
||||
self.assertTrue(result)
|
||||
target_dir = Path(self.tmp_dir) / "resources" / "plate_96"
|
||||
self.assertTrue(target_dir.exists())
|
||||
|
||||
@patch("unilabos.app.model_upload._download_file")
|
||||
def test_download_with_children_mesh(self, mock_download):
|
||||
"""下载包含 children_mesh 的模型"""
|
||||
result = download_model_from_oss(
|
||||
{
|
||||
"mesh": "tip_rack",
|
||||
"type": "device",
|
||||
"path": "https://oss.example.com/tip_rack/model.xacro",
|
||||
"children_mesh": {
|
||||
"path": "https://oss.example.com/tip_rack/meshes/tip.stl",
|
||||
"format": "stl",
|
||||
},
|
||||
},
|
||||
mesh_base_dir=Path(self.tmp_dir),
|
||||
)
|
||||
self.assertTrue(result)
|
||||
# 应调用两次:入口文件 + children_mesh
|
||||
self.assertEqual(mock_download.call_count, 2)
|
||||
|
||||
@patch("unilabos.app.model_upload._download_file", side_effect=Exception("network error"))
|
||||
def test_download_failure_graceful(self, mock_download):
|
||||
"""下载失败时返回 False(不抛异常)"""
|
||||
result = download_model_from_oss(
|
||||
{"mesh": "broken", "type": "device", "path": "https://oss.example.com/broken.xacro"},
|
||||
mesh_base_dir=Path(self.tmp_dir),
|
||||
)
|
||||
self.assertFalse(result)
|
||||
|
||||
|
||||
class TestModelExtensions(unittest.TestCase):
|
||||
"""测试支持的模型文件后缀集合"""
|
||||
|
||||
def test_standard_extensions(self):
|
||||
"""确认标准 3D 格式在支持列表中"""
|
||||
expected = {".stl", ".gltf", ".glb", ".xacro", ".urdf", ".obj", ".dae"}
|
||||
for ext in expected:
|
||||
self.assertIn(ext, _MODEL_EXTENSIONS, f"{ext} should be supported")
|
||||
|
||||
def test_non_model_excluded(self):
|
||||
"""非模型文件后缀不在列表中"""
|
||||
excluded = {".txt", ".json", ".py", ".png", ".jpg"}
|
||||
for ext in excluded:
|
||||
self.assertNotIn(ext, _MODEL_EXTENSIONS, f"{ext} should not be supported")
|
||||
|
||||
|
||||
class TestXorTransform(unittest.TestCase):
|
||||
"""XOR 加密/解密核心函数测试。"""
|
||||
|
||||
def test_roundtrip_symmetry(self):
|
||||
"""XOR 加密后再解密恢复原始数据(对称性)。"""
|
||||
original = b"Hello, this is a test model file content."
|
||||
encrypted = _xor_transform(original)
|
||||
self.assertNotEqual(encrypted, original)
|
||||
decrypted = _xor_transform(encrypted)
|
||||
self.assertEqual(decrypted, original)
|
||||
|
||||
def test_empty_data(self):
|
||||
"""空数据加密后仍为空。"""
|
||||
result = _xor_transform(b"")
|
||||
self.assertEqual(result, b"")
|
||||
|
||||
def test_single_byte(self):
|
||||
"""单字节数据正确加解密。"""
|
||||
original = b"\xff"
|
||||
encrypted = _xor_transform(original)
|
||||
decrypted = _xor_transform(encrypted)
|
||||
self.assertEqual(decrypted, original)
|
||||
|
||||
def test_data_longer_than_key(self):
|
||||
"""超过密钥长度(32 字节)的数据正确循环 XOR。"""
|
||||
original = bytes(range(256)) * 2 # 512 字节
|
||||
encrypted = _xor_transform(original)
|
||||
self.assertNotEqual(encrypted, original)
|
||||
decrypted = _xor_transform(encrypted)
|
||||
self.assertEqual(decrypted, original)
|
||||
|
||||
def test_data_exactly_key_length(self):
|
||||
"""恰好 32 字节(密钥长度)的数据正确处理。"""
|
||||
original = bytes(range(32))
|
||||
encrypted = _xor_transform(original)
|
||||
decrypted = _xor_transform(encrypted)
|
||||
self.assertEqual(decrypted, original)
|
||||
|
||||
def test_all_zeros_produces_key(self):
|
||||
"""全零数据 XOR 后结果应为密钥本身。"""
|
||||
zeros = b"\x00" * 32
|
||||
result = _xor_transform(zeros)
|
||||
key = os.environ.get(
|
||||
"UNILAB_MESH_XOR_KEY", "unilab3d-model-protection-key-v1"
|
||||
).encode()
|
||||
self.assertEqual(result, key)
|
||||
|
||||
def test_custom_key(self):
|
||||
"""自定义密钥正确加解密。"""
|
||||
custom_key = b"custom-key-12345"
|
||||
original = b"test data for custom key"
|
||||
encrypted = _xor_transform(original, key=custom_key)
|
||||
decrypted = _xor_transform(encrypted, key=custom_key)
|
||||
self.assertEqual(decrypted, original)
|
||||
|
||||
def test_different_keys_produce_different_results(self):
|
||||
"""不同密钥产生不同加密结果。"""
|
||||
data = b"same data"
|
||||
key1 = b"key-one-is-here!"
|
||||
key2 = b"key-two-is-here!"
|
||||
self.assertNotEqual(_xor_transform(data, key1), _xor_transform(data, key2))
|
||||
|
||||
def test_binary_stl_header(self):
|
||||
"""二进制内容(模拟 STL 文件头)正确加解密。"""
|
||||
stl_header = b"\x00" * 80 + b"\x03\x00\x00\x00"
|
||||
encrypted = _xor_transform(stl_header)
|
||||
decrypted = _xor_transform(encrypted)
|
||||
self.assertEqual(decrypted, stl_header)
|
||||
|
||||
def test_large_data_roundtrip(self):
|
||||
"""大数据(1MB)加解密正确性。"""
|
||||
original = os.urandom(1024 * 1024)
|
||||
encrypted = _xor_transform(original)
|
||||
decrypted = _xor_transform(encrypted)
|
||||
self.assertEqual(decrypted, original)
|
||||
|
||||
def test_consistency_with_frontend_key(self):
|
||||
"""验证 Python 端与前端使用相同的默认密钥。"""
|
||||
frontend_key = b"unilab3d-model-protection-key-v1"
|
||||
data = b"cross-platform test data"
|
||||
encrypted = _xor_transform(data, key=frontend_key)
|
||||
# 用默认密钥解密(应一致)
|
||||
decrypted = _xor_transform(encrypted)
|
||||
self.assertEqual(decrypted, data)
|
||||
|
||||
|
||||
class TestEncryptExtensions(unittest.TestCase):
|
||||
"""加密文件扩展名配置测试。"""
|
||||
|
||||
def test_all_mesh_formats_in_encrypt_set(self):
|
||||
"""所有 mesh 格式都在加密扩展名集合中。"""
|
||||
expected = {".stl", ".dae", ".obj", ".fbx", ".gltf", ".glb"}
|
||||
self.assertEqual(_MESH_ENCRYPT_EXTENSIONS, expected)
|
||||
|
||||
def test_xml_formats_not_encrypted(self):
|
||||
"""XACRO/URDF/YAML 文件不加密。"""
|
||||
for ext in {".xacro", ".urdf", ".yaml", ".yml"}:
|
||||
self.assertNotIn(ext, _MESH_ENCRYPT_EXTENSIONS)
|
||||
|
||||
def test_encrypt_is_subset_of_model_extensions(self):
|
||||
"""加密扩展名是模型扩展名的子集。"""
|
||||
self.assertTrue(_MESH_ENCRYPT_EXTENSIONS.issubset(_MODEL_EXTENSIONS))
|
||||
|
||||
|
||||
class TestPutUploadEncryption(unittest.TestCase):
|
||||
"""_put_upload 中的条件加密测试。"""
|
||||
|
||||
@patch("unilabos.app.model_upload.requests.put")
|
||||
def test_stl_file_encrypted_before_upload(self, mock_put):
|
||||
"""STL 文件上传前自动 XOR 加密。"""
|
||||
from unilabos.app.model_upload import _put_upload
|
||||
|
||||
original_data = b"solid test\nfacet normal 0 0 1\n"
|
||||
with tempfile.NamedTemporaryFile(suffix=".stl", delete=False) as f:
|
||||
f.write(original_data)
|
||||
f.flush()
|
||||
tmp_path = Path(f.name)
|
||||
|
||||
try:
|
||||
mock_put.return_value = MagicMock(status_code=200)
|
||||
mock_put.return_value.raise_for_status = MagicMock()
|
||||
_put_upload(tmp_path, "https://oss.example.com/upload")
|
||||
|
||||
uploaded_data = mock_put.call_args.kwargs.get("data")
|
||||
self.assertIsNotNone(uploaded_data)
|
||||
self.assertNotEqual(uploaded_data, original_data)
|
||||
# 解密后应恢复原始数据
|
||||
self.assertEqual(_xor_transform(uploaded_data), original_data)
|
||||
finally:
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
|
||||
@patch("unilabos.app.model_upload.requests.put")
|
||||
def test_xacro_file_not_encrypted(self, mock_put):
|
||||
"""XACRO 文件上传时不加密。"""
|
||||
from unilabos.app.model_upload import _put_upload
|
||||
|
||||
original_data = b'<?xml version="1.0"?><robot></robot>'
|
||||
with tempfile.NamedTemporaryFile(suffix=".xacro", delete=False) as f:
|
||||
f.write(original_data)
|
||||
f.flush()
|
||||
tmp_path = Path(f.name)
|
||||
|
||||
try:
|
||||
mock_put.return_value = MagicMock(status_code=200)
|
||||
mock_put.return_value.raise_for_status = MagicMock()
|
||||
_put_upload(tmp_path, "https://oss.example.com/upload")
|
||||
|
||||
uploaded_data = mock_put.call_args.kwargs.get("data")
|
||||
self.assertEqual(uploaded_data, original_data)
|
||||
finally:
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
|
||||
@patch("unilabos.app.model_upload.requests.put")
|
||||
def test_all_mesh_formats_encrypted(self, mock_put):
|
||||
"""所有 mesh 格式上传前都加密。"""
|
||||
from unilabos.app.model_upload import _put_upload
|
||||
|
||||
original_data = b"test mesh binary data content"
|
||||
for ext in [".stl", ".dae", ".obj", ".fbx", ".gltf", ".glb"]:
|
||||
with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as f:
|
||||
f.write(original_data)
|
||||
f.flush()
|
||||
tmp_path = Path(f.name)
|
||||
try:
|
||||
mock_put.reset_mock()
|
||||
mock_put.return_value = MagicMock(status_code=200)
|
||||
mock_put.return_value.raise_for_status = MagicMock()
|
||||
_put_upload(tmp_path, "https://oss.example.com/upload")
|
||||
|
||||
uploaded_data = mock_put.call_args.kwargs.get("data")
|
||||
self.assertNotEqual(uploaded_data, original_data, f"{ext} 文件应被加密")
|
||||
finally:
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
|
||||
@patch("unilabos.app.model_upload.requests.put")
|
||||
def test_uppercase_extension_encrypted(self, mock_put):
|
||||
"""大写扩展名 .STL 也被加密(大小写不敏感)。"""
|
||||
from unilabos.app.model_upload import _put_upload
|
||||
|
||||
original_data = b"uppercase ext test"
|
||||
with tempfile.NamedTemporaryFile(suffix=".STL", delete=False) as f:
|
||||
f.write(original_data)
|
||||
f.flush()
|
||||
tmp_path = Path(f.name)
|
||||
try:
|
||||
mock_put.return_value = MagicMock(status_code=200)
|
||||
mock_put.return_value.raise_for_status = MagicMock()
|
||||
_put_upload(tmp_path, "https://oss.example.com/upload")
|
||||
|
||||
uploaded_data = mock_put.call_args.kwargs.get("data")
|
||||
self.assertNotEqual(uploaded_data, original_data)
|
||||
finally:
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
|
||||
|
||||
class TestDownloadFileDecryption(unittest.TestCase):
|
||||
"""_download_file 中的条件解密测试。"""
|
||||
|
||||
@patch("unilabos.app.model_upload.requests.get")
|
||||
def test_mesh_file_decrypted_on_download(self, mock_get):
|
||||
"""下载的 mesh 文件自动 XOR 解密后存本地。"""
|
||||
from unilabos.app.model_upload import _download_file
|
||||
|
||||
original_data = b"original stl content here"
|
||||
encrypted_data = _xor_transform(original_data)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = encrypted_data
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
local_path = Path(tmpdir) / "model.stl"
|
||||
_download_file("https://oss.example.com/model.stl", local_path)
|
||||
self.assertEqual(local_path.read_bytes(), original_data)
|
||||
|
||||
@patch("unilabos.app.model_upload.requests.get")
|
||||
def test_xacro_file_not_decrypted(self, mock_get):
|
||||
"""下载的 XACRO 文件不做解密处理。"""
|
||||
from unilabos.app.model_upload import _download_file
|
||||
|
||||
xml_data = b'<?xml version="1.0"?><robot></robot>'
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = xml_data
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
local_path = Path(tmpdir) / "macro.xacro"
|
||||
_download_file("https://oss.example.com/macro.xacro", local_path)
|
||||
self.assertEqual(local_path.read_bytes(), xml_data)
|
||||
|
||||
@patch("unilabos.app.model_upload.requests.get")
|
||||
def test_upload_download_roundtrip(self, mock_get):
|
||||
"""上传加密 → 下载解密的完整 round-trip。"""
|
||||
from unilabos.app.model_upload import _download_file
|
||||
|
||||
original_data = b"binary stl mesh \x00\xff\x80 special bytes"
|
||||
encrypted_data = _xor_transform(original_data)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = encrypted_data
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
local_path = Path(tmpdir) / "mesh.stl"
|
||||
_download_file("https://oss.example.com/mesh.stl", local_path)
|
||||
self.assertEqual(local_path.read_bytes(), original_data)
|
||||
|
||||
@patch("unilabos.app.model_upload.requests.get")
|
||||
def test_all_mesh_formats_decrypted(self, mock_get):
|
||||
"""所有 mesh 格式下载后都解密。"""
|
||||
from unilabos.app.model_upload import _download_file
|
||||
|
||||
original_data = b"mesh content for roundtrip"
|
||||
encrypted_data = _xor_transform(original_data)
|
||||
|
||||
for ext in [".stl", ".dae", ".obj", ".fbx", ".gltf", ".glb"]:
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = encrypted_data
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
local_path = Path(tmpdir) / f"model{ext}"
|
||||
_download_file(f"https://oss.example.com/model{ext}", local_path)
|
||||
self.assertEqual(
|
||||
local_path.read_bytes(), original_data, f"{ext} 文件应被解密"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
170
tests/app/test_normalize_model.py
Normal file
170
tests/app/test_normalize_model.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""normalize_model_for_upload 单元测试"""
|
||||
|
||||
import unittest
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加项目根目录到 sys.path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
|
||||
from unilabos.app.register import normalize_model_for_upload
|
||||
|
||||
|
||||
class TestNormalizeModelForUpload(unittest.TestCase):
|
||||
"""测试 Registry YAML model 字段标准化"""
|
||||
|
||||
def test_empty_input(self):
|
||||
"""空 dict 直接返回"""
|
||||
self.assertEqual(normalize_model_for_upload({}), {})
|
||||
self.assertIsNone(normalize_model_for_upload(None))
|
||||
|
||||
def test_format_infer_xacro(self):
|
||||
"""自动从 path 后缀推断 format=xacro"""
|
||||
model = {
|
||||
"path": "https://oss.example.com/devices/arm/macro_device.xacro",
|
||||
"mesh": "arm_slider",
|
||||
"type": "device",
|
||||
}
|
||||
result = normalize_model_for_upload(model)
|
||||
self.assertEqual(result["format"], "xacro")
|
||||
|
||||
def test_format_infer_urdf(self):
|
||||
"""自动推断 format=urdf"""
|
||||
model = {"path": "https://example.com/robot.urdf", "type": "device"}
|
||||
result = normalize_model_for_upload(model)
|
||||
self.assertEqual(result["format"], "urdf")
|
||||
|
||||
def test_format_infer_stl(self):
|
||||
"""自动推断 format=stl"""
|
||||
model = {"path": "https://example.com/part.stl"}
|
||||
result = normalize_model_for_upload(model)
|
||||
self.assertEqual(result["format"], "stl")
|
||||
|
||||
def test_format_infer_gltf(self):
|
||||
"""自动推断 format=gltf(.gltf 和 .glb)"""
|
||||
for ext in [".gltf", ".glb"]:
|
||||
model = {"path": f"https://example.com/model{ext}"}
|
||||
result = normalize_model_for_upload(model)
|
||||
self.assertEqual(result["format"], "gltf", f"failed for {ext}")
|
||||
|
||||
def test_format_not_overwritten(self):
|
||||
"""已有 format 字段时不覆盖"""
|
||||
model = {
|
||||
"path": "https://example.com/model.xacro",
|
||||
"format": "custom",
|
||||
}
|
||||
result = normalize_model_for_upload(model)
|
||||
self.assertEqual(result["format"], "custom")
|
||||
|
||||
def test_format_no_path(self):
|
||||
"""没有 path 时不推断 format"""
|
||||
model = {"mesh": "arm_slider", "type": "device"}
|
||||
result = normalize_model_for_upload(model)
|
||||
self.assertNotIn("format", result)
|
||||
|
||||
def test_children_mesh_string_to_struct(self):
|
||||
"""将 children_mesh 字符串(旧格式)转为结构化对象"""
|
||||
model = {
|
||||
"path": "https://example.com/rack.xacro",
|
||||
"type": "resource",
|
||||
"children_mesh": "tip/meshes/tip.stl",
|
||||
"children_mesh_tf": [0.0045, 0.0045, 0, 0, 0, 1.57],
|
||||
"children_mesh_path": "https://oss.example.com/tip.stl",
|
||||
}
|
||||
result = normalize_model_for_upload(model)
|
||||
|
||||
cm = result["children_mesh"]
|
||||
self.assertIsInstance(cm, dict)
|
||||
self.assertEqual(cm["path"], "https://oss.example.com/tip.stl")
|
||||
self.assertEqual(cm["format"], "stl")
|
||||
self.assertTrue(cm["default_visible"])
|
||||
self.assertEqual(cm["local_offset"], [0.0045, 0.0045, 0])
|
||||
self.assertEqual(cm["local_rotation"], [0, 0, 1.57])
|
||||
|
||||
self.assertNotIn("children_mesh_tf", result)
|
||||
self.assertNotIn("children_mesh_path", result)
|
||||
|
||||
def test_children_mesh_no_oss_fallback(self):
|
||||
"""children_mesh 无 OSS URL 时 fallback 到本地路径"""
|
||||
model = {
|
||||
"path": "https://example.com/rack.xacro",
|
||||
"children_mesh": "plate_96/meshes/plate_96.stl",
|
||||
}
|
||||
result = normalize_model_for_upload(model)
|
||||
cm = result["children_mesh"]
|
||||
self.assertEqual(cm["path"], "plate_96/meshes/plate_96.stl")
|
||||
self.assertEqual(cm["format"], "stl")
|
||||
|
||||
def test_children_mesh_gltf_format(self):
|
||||
"""children_mesh .glb 文件推断 format=gltf"""
|
||||
model = {
|
||||
"path": "https://example.com/rack.xacro",
|
||||
"children_mesh": "meshes/child.glb",
|
||||
}
|
||||
result = normalize_model_for_upload(model)
|
||||
self.assertEqual(result["children_mesh"]["format"], "gltf")
|
||||
|
||||
def test_children_mesh_partial_tf(self):
|
||||
"""children_mesh_tf 只有 3 个值时只有 offset 无 rotation"""
|
||||
model = {
|
||||
"path": "https://example.com/rack.xacro",
|
||||
"children_mesh": "tip.stl",
|
||||
"children_mesh_tf": [0.01, 0.02, 0.03],
|
||||
}
|
||||
result = normalize_model_for_upload(model)
|
||||
cm = result["children_mesh"]
|
||||
self.assertEqual(cm["local_offset"], [0.01, 0.02, 0.03])
|
||||
self.assertNotIn("local_rotation", cm)
|
||||
|
||||
def test_children_mesh_no_tf(self):
|
||||
"""children_mesh 无 tf 时不加 offset/rotation"""
|
||||
model = {
|
||||
"path": "https://example.com/rack.xacro",
|
||||
"children_mesh": "tip.stl",
|
||||
}
|
||||
result = normalize_model_for_upload(model)
|
||||
cm = result["children_mesh"]
|
||||
self.assertNotIn("local_offset", cm)
|
||||
self.assertNotIn("local_rotation", cm)
|
||||
|
||||
def test_children_mesh_already_dict(self):
|
||||
"""children_mesh 已经是 dict 时不重新映射"""
|
||||
model = {
|
||||
"path": "https://example.com/rack.xacro",
|
||||
"children_mesh": {
|
||||
"path": "https://example.com/tip.stl",
|
||||
"format": "stl",
|
||||
"default_visible": False,
|
||||
},
|
||||
}
|
||||
result = normalize_model_for_upload(model)
|
||||
cm = result["children_mesh"]
|
||||
self.assertIsInstance(cm, dict)
|
||||
self.assertFalse(cm["default_visible"])
|
||||
|
||||
def test_original_not_mutated(self):
|
||||
"""原始 dict 不被修改"""
|
||||
original = {
|
||||
"path": "https://example.com/model.xacro",
|
||||
"mesh": "arm",
|
||||
}
|
||||
original_copy = {**original}
|
||||
normalize_model_for_upload(original)
|
||||
self.assertEqual(original, original_copy)
|
||||
|
||||
def test_preserves_existing_fields(self):
|
||||
"""所有原始字段都被保留"""
|
||||
model = {
|
||||
"path": "https://example.com/model.xacro",
|
||||
"mesh": "arm_slider",
|
||||
"type": "device",
|
||||
"mesh_tf": [0, 0, 0, 0, 0, 0],
|
||||
"custom_field": "should_survive",
|
||||
}
|
||||
result = normalize_model_for_upload(model)
|
||||
self.assertEqual(result["custom_field"], "should_survive")
|
||||
self.assertEqual(result["mesh_tf"], [0, 0, 0, 0, 0, 0])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
939
tests/ros/test_joint_state_bridge.py
Normal file
939
tests/ros/test_joint_state_bridge.py
Normal file
@@ -0,0 +1,939 @@
|
||||
"""
|
||||
P1 关节数据 & 资源跟随桥接测试 — 全面覆盖 HostNode 关节回调 + resource_pose 回调的边缘 case。
|
||||
|
||||
不依赖 ROS2 运行时,通过 mock 模拟 msg 和 bridge。
|
||||
|
||||
测试分组:
|
||||
E1: JointRepublisher JSON 输出格式 (已修复 str→json.dumps)
|
||||
E2: 关节状态回调 — 从 /joint_states (JointState msg) 直接读取 name/position
|
||||
E3: 资源跟随 (resource_pose) — 夹爪抓取/释放/多资源
|
||||
E4: 联合流程 — 关节 + 资源一并通过 bridge 发送
|
||||
E5: Bridge 调用验证
|
||||
E6: 同类型设备多实例 — 重复关节名场景
|
||||
E7: 吞吐优化 — 死区过滤、抑频、增量 resource_poses
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
from types import SimpleNamespace
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
# ==================== 辅助: 模拟 JointState msg ====================
|
||||
|
||||
|
||||
def _make_joint_state_msg(names: list, positions: list, velocities=None, efforts=None):
|
||||
"""构造模拟的 sensor_msgs/JointState 消息(不依赖 ROS2)"""
|
||||
msg = SimpleNamespace()
|
||||
msg.name = names
|
||||
msg.position = positions
|
||||
msg.velocity = velocities or [0.0] * len(names)
|
||||
msg.effort = efforts or [0.0] * len(names)
|
||||
return msg
|
||||
|
||||
|
||||
def _make_string_msg(data: str):
|
||||
"""构造模拟的 std_msgs/String 消息"""
|
||||
msg = SimpleNamespace()
|
||||
msg.data = data
|
||||
return msg
|
||||
|
||||
|
||||
# ==================== 辅助: 提取 HostNode 核心逻辑用于隔离测试 ====================
|
||||
|
||||
|
||||
class JointBridgeSimulator:
|
||||
"""
|
||||
模拟 HostNode 的关节桥接核心逻辑(提取自 host_node.py),
|
||||
不依赖 ROS2 Node、subscription 等基础设施。
|
||||
|
||||
包含吞吐优化逻辑:
|
||||
- 死区过滤 (dead band): 关节变化 < 阈值时不发送
|
||||
- 抑频 (throttle): 限制最大发送频率
|
||||
- 增量 resource_poses: 仅在变化时附带
|
||||
"""
|
||||
|
||||
JOINT_DEAD_BAND: float = 1e-4
|
||||
JOINT_MIN_INTERVAL: float = 0.05 # 秒
|
||||
|
||||
def __init__(self, device_uuid_map: Dict[str, str],
|
||||
dead_band: Optional[float] = None,
|
||||
min_interval: Optional[float] = None):
|
||||
self.device_uuid_map = device_uuid_map
|
||||
self._device_ids_sorted = sorted(device_uuid_map.keys(), key=len, reverse=True)
|
||||
self._resource_poses: Dict[str, str] = {}
|
||||
self._resource_poses_dirty: bool = False
|
||||
self._last_joint_values: Dict[str, float] = {}
|
||||
self._last_send_time: float = -float("inf") # 确保首条消息总是通过
|
||||
# 允许测试覆盖优化参数
|
||||
if dead_band is not None:
|
||||
self.JOINT_DEAD_BAND = dead_band
|
||||
if min_interval is not None:
|
||||
self.JOINT_MIN_INTERVAL = min_interval
|
||||
|
||||
def resource_pose_callback(self, msg) -> None:
|
||||
"""模拟 HostNode._resource_pose_callback(含变化检测)"""
|
||||
try:
|
||||
data = json.loads(msg.data)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
return
|
||||
if not isinstance(data, dict) or not data:
|
||||
return
|
||||
has_change = False
|
||||
for k, v in data.items():
|
||||
if self._resource_poses.get(k) != v:
|
||||
has_change = True
|
||||
break
|
||||
if has_change:
|
||||
self._resource_poses.update(data)
|
||||
self._resource_poses_dirty = True
|
||||
|
||||
def joint_state_callback(self, msg, now: Optional[float] = None) -> dict:
|
||||
"""
|
||||
模拟 HostNode._joint_state_callback 核心逻辑(含优化)。
|
||||
now 参数允许测试控制时间。
|
||||
返回 {device_id: {"node_uuid": ..., "joint_states": {...}, "resource_poses": {...}}}。
|
||||
返回 {} 表示被优化过滤。
|
||||
"""
|
||||
names = list(msg.name)
|
||||
positions = list(msg.position)
|
||||
if not names or len(names) != len(positions):
|
||||
return {}
|
||||
|
||||
if now is None:
|
||||
now = time.time()
|
||||
resource_dirty = self._resource_poses_dirty
|
||||
|
||||
# 抑频检查
|
||||
if not resource_dirty and (now - self._last_send_time) < self.JOINT_MIN_INTERVAL:
|
||||
return {}
|
||||
|
||||
# 死区过滤
|
||||
has_significant_change = False
|
||||
for name, pos in zip(names, positions):
|
||||
last_val = self._last_joint_values.get(name)
|
||||
if last_val is None or abs(float(pos) - last_val) >= self.JOINT_DEAD_BAND:
|
||||
has_significant_change = True
|
||||
break
|
||||
|
||||
if not has_significant_change and not resource_dirty:
|
||||
return {}
|
||||
|
||||
# 更新状态
|
||||
for name, pos in zip(names, positions):
|
||||
self._last_joint_values[name] = float(pos)
|
||||
self._last_send_time = now
|
||||
|
||||
# 按设备 ID 分组关节数据
|
||||
device_joints: Dict[str, Dict[str, float]] = {}
|
||||
for name, pos in zip(names, positions):
|
||||
matched_device = None
|
||||
for device_id in self._device_ids_sorted:
|
||||
if name.startswith(device_id + "_"):
|
||||
matched_device = device_id
|
||||
break
|
||||
if matched_device:
|
||||
if matched_device not in device_joints:
|
||||
device_joints[matched_device] = {}
|
||||
device_joints[matched_device][name] = float(pos)
|
||||
elif len(self.device_uuid_map) == 1:
|
||||
fallback_id = self._device_ids_sorted[0]
|
||||
if fallback_id not in device_joints:
|
||||
device_joints[fallback_id] = {}
|
||||
device_joints[fallback_id][name] = float(pos)
|
||||
|
||||
# 构建设备级 resource_poses(仅 dirty 时附带)
|
||||
device_resource_poses: Dict[str, Dict[str, str]] = {}
|
||||
if resource_dirty:
|
||||
for resource_id, link_name in self._resource_poses.items():
|
||||
matched_device = None
|
||||
for device_id in self._device_ids_sorted:
|
||||
if link_name.startswith(device_id + "_"):
|
||||
matched_device = device_id
|
||||
break
|
||||
if matched_device:
|
||||
if matched_device not in device_resource_poses:
|
||||
device_resource_poses[matched_device] = {}
|
||||
device_resource_poses[matched_device][resource_id] = link_name
|
||||
elif len(self.device_uuid_map) == 1:
|
||||
fallback_id = self._device_ids_sorted[0]
|
||||
if fallback_id not in device_resource_poses:
|
||||
device_resource_poses[fallback_id] = {}
|
||||
device_resource_poses[fallback_id][resource_id] = link_name
|
||||
self._resource_poses_dirty = False
|
||||
|
||||
result = {}
|
||||
for device_id, joint_states in device_joints.items():
|
||||
node_uuid = self.device_uuid_map.get(device_id)
|
||||
if not node_uuid:
|
||||
continue
|
||||
result[device_id] = {
|
||||
"node_uuid": node_uuid,
|
||||
"joint_states": joint_states,
|
||||
"resource_poses": device_resource_poses.get(device_id, {}),
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
# 功能测试中禁用优化(dead_band=0, min_interval=0),确保逻辑正确性
|
||||
def _make_sim(device_uuid_map: Dict[str, str]) -> JointBridgeSimulator:
|
||||
"""创建禁用吞吐优化的模拟器(用于功能正确性测试)"""
|
||||
return JointBridgeSimulator(device_uuid_map, dead_band=0.0, min_interval=0.0)
|
||||
|
||||
|
||||
# ==================== E1: JointRepublisher JSON 输出 ====================
|
||||
|
||||
|
||||
class TestJointRepublisherFormat:
|
||||
"""验证 JointRepublisher 输出标准 JSON(双引号)而非 Python repr(单引号)"""
|
||||
|
||||
def test_output_is_valid_json(self):
|
||||
"""str() 产生单引号,json.dumps() 产生双引号"""
|
||||
joint_dict = {
|
||||
"name": ["joint1", "joint2"],
|
||||
"position": [0.1, 0.2],
|
||||
"velocity": [0.0, 0.0],
|
||||
"effort": [0.0, 0.0],
|
||||
}
|
||||
result = json.dumps(joint_dict)
|
||||
parsed = json.loads(result)
|
||||
assert parsed["name"] == ["joint1", "joint2"]
|
||||
assert parsed["position"] == [0.1, 0.2]
|
||||
assert "'" not in result
|
||||
|
||||
def test_str_produces_invalid_json(self):
|
||||
"""对比: str() 不是合法 JSON"""
|
||||
joint_dict = {"name": ["joint1"], "position": [0.1]}
|
||||
result = str(joint_dict)
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
json.loads(result)
|
||||
|
||||
|
||||
# ==================== E2: 关节状态回调(JointState msg 直接读取)====================
|
||||
|
||||
|
||||
class TestJointStateCallback:
|
||||
"""测试从 JointState msg 直接读取 name/position 的分组逻辑"""
|
||||
|
||||
def test_single_device_simple(self):
|
||||
"""单设备,关节名有设备前缀"""
|
||||
sim = _make_sim({"panda": "uuid-panda"})
|
||||
msg = _make_joint_state_msg(
|
||||
["panda_joint1", "panda_joint2"], [0.5, 1.0]
|
||||
)
|
||||
result = sim.joint_state_callback(msg)
|
||||
assert "panda" in result
|
||||
assert result["panda"]["joint_states"]["panda_joint1"] == 0.5
|
||||
assert result["panda"]["joint_states"]["panda_joint2"] == 1.0
|
||||
|
||||
def test_single_device_no_prefix_fallback(self):
|
||||
"""单设备,关节名无设备前缀 → 应归入唯一设备"""
|
||||
sim = _make_sim({"robot1": "uuid-r1"})
|
||||
msg = _make_joint_state_msg(["joint_a", "joint_b"], [1.0, 2.0])
|
||||
result = sim.joint_state_callback(msg)
|
||||
assert "robot1" in result
|
||||
assert result["robot1"]["joint_states"]["joint_a"] == 1.0
|
||||
assert result["robot1"]["joint_states"]["joint_b"] == 2.0
|
||||
|
||||
def test_multi_device_distinct_prefixes(self):
|
||||
"""多设备,不同前缀,正确分组"""
|
||||
sim = _make_sim({"arm1": "uuid-arm1", "arm2": "uuid-arm2"})
|
||||
msg = _make_joint_state_msg(
|
||||
["arm1_j1", "arm1_j2", "arm2_j1", "arm2_j2"],
|
||||
[0.1, 0.2, 0.3, 0.4],
|
||||
)
|
||||
result = sim.joint_state_callback(msg)
|
||||
assert result["arm1"]["joint_states"]["arm1_j1"] == 0.1
|
||||
assert result["arm1"]["joint_states"]["arm1_j2"] == 0.2
|
||||
assert result["arm2"]["joint_states"]["arm2_j1"] == 0.3
|
||||
assert result["arm2"]["joint_states"]["arm2_j2"] == 0.4
|
||||
|
||||
def test_ambiguous_prefix_longest_wins(self):
|
||||
"""前缀歧义: arm 和 arm_left — 最长前缀优先"""
|
||||
sim = _make_sim({"arm": "uuid-arm", "arm_left": "uuid-arm-left"})
|
||||
msg = _make_joint_state_msg(
|
||||
["arm_j1", "arm_left_j1", "arm_left_j2"],
|
||||
[0.1, 0.2, 0.3],
|
||||
)
|
||||
result = sim.joint_state_callback(msg)
|
||||
assert result["arm"]["joint_states"]["arm_j1"] == 0.1
|
||||
assert result["arm_left"]["joint_states"]["arm_left_j1"] == 0.2
|
||||
assert result["arm_left"]["joint_states"]["arm_left_j2"] == 0.3
|
||||
|
||||
def test_multi_device_unmatched_joints_dropped(self):
|
||||
"""多设备时,无法匹配前缀的关节应被丢弃(不 fallback)"""
|
||||
sim = _make_sim({"arm1": "uuid-arm1", "arm2": "uuid-arm2"})
|
||||
msg = _make_joint_state_msg(
|
||||
["arm1_j1", "unknown_j1"],
|
||||
[0.1, 0.9],
|
||||
)
|
||||
result = sim.joint_state_callback(msg)
|
||||
assert result["arm1"]["joint_states"]["arm1_j1"] == 0.1
|
||||
for device_id, data in result.items():
|
||||
assert "unknown_j1" not in data["joint_states"]
|
||||
|
||||
def test_empty_names(self):
|
||||
"""空 name 列表"""
|
||||
sim = _make_sim({"dev": "uuid-dev"})
|
||||
msg = _make_joint_state_msg([], [])
|
||||
result = sim.joint_state_callback(msg)
|
||||
assert result == {}
|
||||
|
||||
def test_mismatched_lengths(self):
|
||||
"""name 和 position 长度不一致"""
|
||||
sim = _make_sim({"dev": "uuid-dev"})
|
||||
msg = _make_joint_state_msg(["j1", "j2"], [0.1])
|
||||
result = sim.joint_state_callback(msg)
|
||||
assert result == {}
|
||||
|
||||
def test_no_devices(self):
|
||||
"""无设备 UUID 映射"""
|
||||
sim = _make_sim({})
|
||||
msg = _make_joint_state_msg(["j1"], [0.1])
|
||||
result = sim.joint_state_callback(msg)
|
||||
assert result == {}
|
||||
|
||||
def test_numeric_prefix_device_ids(self):
|
||||
"""数字化设备 ID (如 deck1, deck12) — deck12_slot1 不应匹配 deck1"""
|
||||
sim = _make_sim({"deck1": "uuid-d1", "deck12": "uuid-d12"})
|
||||
msg = _make_joint_state_msg(
|
||||
["deck1_slot1", "deck12_slot1"],
|
||||
[1.0, 2.0],
|
||||
)
|
||||
result = sim.joint_state_callback(msg)
|
||||
assert result["deck1"]["joint_states"]["deck1_slot1"] == 1.0
|
||||
assert result["deck12"]["joint_states"]["deck12_slot1"] == 2.0
|
||||
|
||||
def test_position_float_conversion(self):
|
||||
"""position 值应强制转为 float(即使输入为 int)"""
|
||||
sim = _make_sim({"arm": "uuid-arm"})
|
||||
msg = _make_joint_state_msg(["arm_j1"], [1])
|
||||
result = sim.joint_state_callback(msg)
|
||||
assert result["arm"]["joint_states"]["arm_j1"] == 1.0
|
||||
assert isinstance(result["arm"]["joint_states"]["arm_j1"], float)
|
||||
|
||||
def test_node_uuid_in_result(self):
|
||||
"""结果中应携带正确的 node_uuid"""
|
||||
sim = _make_sim({"panda": "uuid-panda-123"})
|
||||
msg = _make_joint_state_msg(["panda_j1"], [0.5])
|
||||
result = sim.joint_state_callback(msg)
|
||||
assert result["panda"]["node_uuid"] == "uuid-panda-123"
|
||||
|
||||
def test_device_with_no_uuid_skipped(self):
|
||||
"""device_uuid_map 中存在映射但值为空 → 跳过"""
|
||||
sim = _make_sim({"arm": ""})
|
||||
msg = _make_joint_state_msg(["arm_j1"], [0.5])
|
||||
result = sim.joint_state_callback(msg)
|
||||
assert result == {}
|
||||
|
||||
def test_many_joints_single_device(self):
|
||||
"""单设备大量关节(如 7-DOF arm)"""
|
||||
sim = _make_sim({"panda": "uuid-panda"})
|
||||
names = [f"panda_joint{i}" for i in range(1, 8)]
|
||||
positions = [float(i) * 0.1 for i in range(1, 8)]
|
||||
msg = _make_joint_state_msg(names, positions)
|
||||
result = sim.joint_state_callback(msg)
|
||||
assert len(result["panda"]["joint_states"]) == 7
|
||||
assert result["panda"]["joint_states"]["panda_joint7"] == pytest.approx(0.7)
|
||||
|
||||
def test_duplicate_joint_names_last_wins(self):
|
||||
"""同类型设备多个实例时,如果关节名完全重复(bug 场景),后出现的值覆盖前者"""
|
||||
sim = _make_sim({"dev": "uuid-dev"})
|
||||
msg = _make_joint_state_msg(["dev_j1", "dev_j1"], [1.0, 2.0])
|
||||
result = sim.joint_state_callback(msg)
|
||||
assert result["dev"]["joint_states"]["dev_j1"] == 2.0
|
||||
|
||||
def test_negative_positions(self):
|
||||
"""关节角度为负数"""
|
||||
sim = _make_sim({"arm": "uuid-arm"})
|
||||
msg = _make_joint_state_msg(["arm_j1", "arm_j2"], [-1.57, -3.14])
|
||||
result = sim.joint_state_callback(msg)
|
||||
assert result["arm"]["joint_states"]["arm_j1"] == pytest.approx(-1.57)
|
||||
assert result["arm"]["joint_states"]["arm_j2"] == pytest.approx(-3.14)
|
||||
|
||||
|
||||
# ==================== E3: 资源跟随 (resource_pose) ====================
|
||||
|
||||
|
||||
class TestResourcePoseCallback:
|
||||
"""测试 resource_pose 回调 — 夹爪抓取/释放/多资源"""
|
||||
|
||||
def test_single_resource_attach(self):
|
||||
"""单个资源挂载到夹爪 link"""
|
||||
sim = _make_sim({"panda": "uuid-panda"})
|
||||
msg = _make_string_msg(json.dumps({"plate_1": "panda_gripper_link"}))
|
||||
sim.resource_pose_callback(msg)
|
||||
assert sim._resource_poses == {"plate_1": "panda_gripper_link"}
|
||||
assert sim._resource_poses_dirty is True
|
||||
|
||||
def test_multiple_resource_attach(self):
|
||||
"""多个资源同时挂载到不同 link"""
|
||||
sim = _make_sim({"panda": "uuid-panda"})
|
||||
msg = _make_string_msg(json.dumps({
|
||||
"plate_1": "panda_gripper_link",
|
||||
"tip_rack": "panda_deck_link",
|
||||
}))
|
||||
sim.resource_pose_callback(msg)
|
||||
assert sim._resource_poses["plate_1"] == "panda_gripper_link"
|
||||
assert sim._resource_poses["tip_rack"] == "panda_deck_link"
|
||||
|
||||
def test_incremental_update(self):
|
||||
"""增量更新:新消息合并到已有状态"""
|
||||
sim = _make_sim({"panda": "uuid-panda"})
|
||||
sim.resource_pose_callback(_make_string_msg(json.dumps({"plate_1": "panda_deck_link"})))
|
||||
sim.resource_pose_callback(_make_string_msg(json.dumps({"plate_2": "panda_gripper_link"})))
|
||||
assert len(sim._resource_poses) == 2
|
||||
assert sim._resource_poses["plate_1"] == "panda_deck_link"
|
||||
assert sim._resource_poses["plate_2"] == "panda_gripper_link"
|
||||
|
||||
def test_resource_reattach(self):
|
||||
"""资源从 deck 移动到 gripper(抓取操作)"""
|
||||
sim = _make_sim({"panda": "uuid-panda"})
|
||||
sim.resource_pose_callback(_make_string_msg(json.dumps({"plate_1": "panda_deck_link"})))
|
||||
assert sim._resource_poses["plate_1"] == "panda_deck_link"
|
||||
sim.resource_pose_callback(_make_string_msg(json.dumps({"plate_1": "panda_gripper_link"})))
|
||||
assert sim._resource_poses["plate_1"] == "panda_gripper_link"
|
||||
|
||||
def test_resource_release_back_to_world(self):
|
||||
"""释放资源回到 world"""
|
||||
sim = _make_sim({"panda": "uuid-panda"})
|
||||
sim.resource_pose_callback(_make_string_msg(json.dumps({"plate_1": "panda_gripper_link"})))
|
||||
sim.resource_pose_callback(_make_string_msg(json.dumps({"plate_1": "world"})))
|
||||
assert sim._resource_poses["plate_1"] == "world"
|
||||
|
||||
def test_empty_dict_heartbeat_no_dirty(self):
|
||||
"""空 dict(心跳包)不标记 dirty"""
|
||||
sim = _make_sim({"panda": "uuid-panda"})
|
||||
sim.resource_pose_callback(_make_string_msg(json.dumps({"plate_1": "panda_link"})))
|
||||
sim._resource_poses_dirty = False # 重置
|
||||
sim.resource_pose_callback(_make_string_msg(json.dumps({})))
|
||||
assert sim._resource_poses_dirty is False # 空 dict 不应标记 dirty
|
||||
|
||||
def test_same_value_no_dirty(self):
|
||||
"""重复发送相同值不应标记 dirty"""
|
||||
sim = _make_sim({"panda": "uuid-panda"})
|
||||
sim.resource_pose_callback(_make_string_msg(json.dumps({"plate_1": "panda_link"})))
|
||||
sim._resource_poses_dirty = False
|
||||
sim.resource_pose_callback(_make_string_msg(json.dumps({"plate_1": "panda_link"})))
|
||||
assert sim._resource_poses_dirty is False
|
||||
|
||||
def test_invalid_json_ignored(self):
|
||||
"""非法 JSON 消息不影响状态"""
|
||||
sim = _make_sim({"panda": "uuid-panda"})
|
||||
sim.resource_pose_callback(_make_string_msg(json.dumps({"plate_1": "panda_link"})))
|
||||
sim.resource_pose_callback(_make_string_msg("not valid json {{{"))
|
||||
assert sim._resource_poses["plate_1"] == "panda_link"
|
||||
|
||||
def test_non_dict_json_ignored(self):
|
||||
"""JSON 但不是 dict(如 list)应被忽略"""
|
||||
sim = _make_sim({"panda": "uuid-panda"})
|
||||
sim.resource_pose_callback(_make_string_msg(json.dumps(["not", "a", "dict"])))
|
||||
assert sim._resource_poses == {}
|
||||
|
||||
def test_python_repr_ignored(self):
|
||||
"""Python repr 格式(单引号)应被忽略"""
|
||||
sim = _make_sim({"panda": "uuid-panda"})
|
||||
sim.resource_pose_callback(_make_string_msg("{'plate_1': 'panda_link'}"))
|
||||
assert sim._resource_poses == {}
|
||||
|
||||
def test_multi_device_resource_attach(self):
|
||||
"""多设备场景:不同设备的 link 挂载不同资源"""
|
||||
sim = _make_sim({"arm1": "uuid-arm1", "arm2": "uuid-arm2"})
|
||||
sim.resource_pose_callback(_make_string_msg(json.dumps({
|
||||
"plate_A": "arm1_gripper_link",
|
||||
"plate_B": "arm2_gripper_link",
|
||||
})))
|
||||
assert sim._resource_poses["plate_A"] == "arm1_gripper_link"
|
||||
assert sim._resource_poses["plate_B"] == "arm2_gripper_link"
|
||||
|
||||
|
||||
# ==================== E4: 联合流程 — 关节 + 资源一并通过 bridge ====================
|
||||
|
||||
|
||||
class TestJointWithResourcePoses:
|
||||
"""测试关节状态回调时,resource_poses 被正确按设备分组并包含在结果中"""
|
||||
|
||||
def test_single_device_joint_with_resource(self):
|
||||
"""单设备:关节更新时携带已挂载的资源"""
|
||||
sim = _make_sim({"panda": "uuid-panda"})
|
||||
sim.resource_pose_callback(_make_string_msg(json.dumps({
|
||||
"plate_1": "panda_gripper_link",
|
||||
})))
|
||||
msg = _make_joint_state_msg(["panda_j1", "panda_j2"], [0.5, 1.0])
|
||||
result = sim.joint_state_callback(msg)
|
||||
assert result["panda"]["resource_poses"] == {"plate_1": "panda_gripper_link"}
|
||||
|
||||
def test_single_device_no_resource(self):
|
||||
"""单设备:无资源挂载时 resource_poses 为空 dict"""
|
||||
sim = _make_sim({"panda": "uuid-panda"})
|
||||
msg = _make_joint_state_msg(["panda_j1"], [0.5])
|
||||
result = sim.joint_state_callback(msg)
|
||||
assert result["panda"]["resource_poses"] == {}
|
||||
|
||||
def test_multi_device_resource_routing(self):
|
||||
"""多设备:资源按 link 前缀路由到正确设备"""
|
||||
sim = _make_sim({"arm1": "uuid-arm1", "arm2": "uuid-arm2"})
|
||||
sim.resource_pose_callback(_make_string_msg(json.dumps({
|
||||
"plate_A": "arm1_gripper_link",
|
||||
"plate_B": "arm2_gripper_link",
|
||||
"tube_1": "arm1_tool_link",
|
||||
})))
|
||||
msg = _make_joint_state_msg(
|
||||
["arm1_j1", "arm2_j1"],
|
||||
[0.1, 0.2],
|
||||
)
|
||||
result = sim.joint_state_callback(msg)
|
||||
assert result["arm1"]["resource_poses"] == {
|
||||
"plate_A": "arm1_gripper_link",
|
||||
"tube_1": "arm1_tool_link",
|
||||
}
|
||||
assert result["arm2"]["resource_poses"] == {"plate_B": "arm2_gripper_link"}
|
||||
|
||||
def test_resource_on_world_frame_not_routed(self):
|
||||
"""资源挂在 world frame(已释放)— 多设备时无法匹配任何设备前缀"""
|
||||
sim = _make_sim({"arm1": "uuid-arm1", "arm2": "uuid-arm2"})
|
||||
sim.resource_pose_callback(_make_string_msg(json.dumps({
|
||||
"plate_A": "world",
|
||||
})))
|
||||
msg = _make_joint_state_msg(["arm1_j1"], [0.1])
|
||||
result = sim.joint_state_callback(msg)
|
||||
assert result["arm1"]["resource_poses"] == {}
|
||||
|
||||
def test_resource_world_frame_single_device_fallback(self):
|
||||
"""单设备时 world frame 的资源走 fallback"""
|
||||
sim = _make_sim({"panda": "uuid-panda"})
|
||||
sim.resource_pose_callback(_make_string_msg(json.dumps({
|
||||
"plate_A": "world",
|
||||
})))
|
||||
msg = _make_joint_state_msg(["panda_j1"], [0.1])
|
||||
result = sim.joint_state_callback(msg)
|
||||
assert result["panda"]["resource_poses"] == {"plate_A": "world"}
|
||||
|
||||
def test_grab_and_move_sequence(self):
|
||||
"""完整夹取序列: 资源在 deck → gripper 抓取 → arm 移动 → 放下"""
|
||||
sim = _make_sim({"panda": "uuid-panda"})
|
||||
|
||||
# 初始: plate 在 deck
|
||||
sim.resource_pose_callback(_make_string_msg(json.dumps({
|
||||
"plate_1": "panda_deck_third_link",
|
||||
})))
|
||||
|
||||
msg = _make_joint_state_msg(
|
||||
["panda_j1", "panda_j2", "panda_j3"],
|
||||
[0.0, -0.5, 1.0],
|
||||
)
|
||||
result = sim.joint_state_callback(msg)
|
||||
assert result["panda"]["resource_poses"]["plate_1"] == "panda_deck_third_link"
|
||||
|
||||
# 抓取: plate 从 deck → gripper
|
||||
sim.resource_pose_callback(_make_string_msg(json.dumps({
|
||||
"plate_1": "panda_gripper_link",
|
||||
})))
|
||||
|
||||
msg = _make_joint_state_msg(
|
||||
["panda_j1", "panda_j2", "panda_j3"],
|
||||
[1.57, 0.0, -0.5],
|
||||
)
|
||||
result = sim.joint_state_callback(msg)
|
||||
assert result["panda"]["resource_poses"]["plate_1"] == "panda_gripper_link"
|
||||
assert result["panda"]["joint_states"]["panda_j1"] == pytest.approx(1.57)
|
||||
|
||||
# 放下: plate 从 gripper → 目标 deck
|
||||
sim.resource_pose_callback(_make_string_msg(json.dumps({
|
||||
"plate_1": "panda_deck_first_link",
|
||||
})))
|
||||
|
||||
msg = _make_joint_state_msg(
|
||||
["panda_j1", "panda_j2", "panda_j3"],
|
||||
[0.0, 0.0, 0.0],
|
||||
)
|
||||
result = sim.joint_state_callback(msg)
|
||||
assert result["panda"]["resource_poses"]["plate_1"] == "panda_deck_first_link"
|
||||
|
||||
def test_simultaneous_grab_multiple_resources(self):
|
||||
"""同时持有多个资源(如双夹爪)"""
|
||||
sim = _make_sim({"panda": "uuid-panda"})
|
||||
sim.resource_pose_callback(_make_string_msg(json.dumps({
|
||||
"plate_1": "panda_left_gripper",
|
||||
"plate_2": "panda_right_gripper",
|
||||
"tip_rack": "panda_deck_link",
|
||||
})))
|
||||
msg = _make_joint_state_msg(["panda_j1"], [0.5])
|
||||
result = sim.joint_state_callback(msg)
|
||||
assert len(result["panda"]["resource_poses"]) == 3
|
||||
|
||||
def test_resource_with_ambiguous_link_prefix(self):
|
||||
"""link 前缀歧义: arm_left_gripper 应匹配 arm_left 而非 arm"""
|
||||
sim = _make_sim({"arm": "uuid-arm", "arm_left": "uuid-arm-left"})
|
||||
sim.resource_pose_callback(_make_string_msg(json.dumps({
|
||||
"plate_A": "arm_gripper_link",
|
||||
"plate_B": "arm_left_gripper_link",
|
||||
})))
|
||||
msg = _make_joint_state_msg(
|
||||
["arm_j1", "arm_left_j1"],
|
||||
[0.1, 0.2],
|
||||
)
|
||||
result = sim.joint_state_callback(msg)
|
||||
assert result["arm"]["resource_poses"] == {"plate_A": "arm_gripper_link"}
|
||||
assert result["arm_left"]["resource_poses"] == {"plate_B": "arm_left_gripper_link"}
|
||||
|
||||
|
||||
# ==================== E5: Bridge 调用验证 ====================
|
||||
|
||||
|
||||
class TestBridgeCalls:
|
||||
"""验证完整桥接流: callback → bridge.publish_joint_state 调用"""
|
||||
|
||||
def test_bridge_called_per_device(self):
|
||||
"""每个设备调用一次 publish_joint_state"""
|
||||
device_uuid_map = {"arm1": "uuid-111", "arm2": "uuid-222"}
|
||||
sim = _make_sim(device_uuid_map)
|
||||
bridge = MagicMock()
|
||||
bridge.publish_joint_state = MagicMock()
|
||||
|
||||
msg = _make_joint_state_msg(
|
||||
["arm1_j1", "arm2_j1"],
|
||||
[1.0, 2.0],
|
||||
)
|
||||
result = sim.joint_state_callback(msg)
|
||||
|
||||
for device_id, data in result.items():
|
||||
bridge.publish_joint_state(
|
||||
data["node_uuid"], data["joint_states"], data["resource_poses"]
|
||||
)
|
||||
|
||||
assert bridge.publish_joint_state.call_count == 2
|
||||
call_uuids = {c[0][0] for c in bridge.publish_joint_state.call_args_list}
|
||||
assert call_uuids == {"uuid-111", "uuid-222"}
|
||||
|
||||
def test_bridge_called_with_resource_poses(self):
|
||||
"""bridge 调用时携带 resource_poses"""
|
||||
sim = _make_sim({"panda": "uuid-panda"})
|
||||
sim.resource_pose_callback(_make_string_msg(json.dumps({
|
||||
"plate_1": "panda_gripper_link",
|
||||
})))
|
||||
|
||||
bridge = MagicMock()
|
||||
msg = _make_joint_state_msg(["panda_j1"], [0.5])
|
||||
result = sim.joint_state_callback(msg)
|
||||
|
||||
for device_id, data in result.items():
|
||||
bridge.publish_joint_state(
|
||||
data["node_uuid"], data["joint_states"], data["resource_poses"]
|
||||
)
|
||||
|
||||
bridge.publish_joint_state.assert_called_once_with(
|
||||
"uuid-panda",
|
||||
{"panda_j1": 0.5},
|
||||
{"plate_1": "panda_gripper_link"},
|
||||
)
|
||||
|
||||
def test_bridge_no_call_for_empty_joints(self):
|
||||
"""无关节数据时不调用 bridge"""
|
||||
sim = _make_sim({"panda": "uuid-panda"})
|
||||
bridge = MagicMock()
|
||||
|
||||
msg = _make_joint_state_msg([], [])
|
||||
result = sim.joint_state_callback(msg)
|
||||
|
||||
for device_id, data in result.items():
|
||||
bridge.publish_joint_state(
|
||||
data["node_uuid"], data["joint_states"], data["resource_poses"]
|
||||
)
|
||||
|
||||
bridge.publish_joint_state.assert_not_called()
|
||||
|
||||
def test_bridge_resource_poses_empty_when_no_resources(self):
|
||||
"""无资源挂载时,resource_poses 参数为空 dict"""
|
||||
sim = _make_sim({"panda": "uuid-panda"})
|
||||
bridge = MagicMock()
|
||||
|
||||
msg = _make_joint_state_msg(["panda_j1"], [0.5])
|
||||
result = sim.joint_state_callback(msg)
|
||||
|
||||
for device_id, data in result.items():
|
||||
bridge.publish_joint_state(
|
||||
data["node_uuid"], data["joint_states"], data["resource_poses"]
|
||||
)
|
||||
|
||||
bridge.publish_joint_state.assert_called_once_with(
|
||||
"uuid-panda",
|
||||
{"panda_j1": 0.5},
|
||||
{},
|
||||
)
|
||||
|
||||
def test_multi_bridge_all_called(self):
|
||||
"""多个 bridge 都应被调用"""
|
||||
sim = _make_sim({"arm": "uuid-arm"})
|
||||
bridges = [MagicMock(), MagicMock()]
|
||||
|
||||
msg = _make_joint_state_msg(["arm_j1"], [0.5])
|
||||
result = sim.joint_state_callback(msg)
|
||||
|
||||
for device_id, data in result.items():
|
||||
for bridge in bridges:
|
||||
bridge.publish_joint_state(
|
||||
data["node_uuid"], data["joint_states"], data["resource_poses"]
|
||||
)
|
||||
|
||||
for bridge in bridges:
|
||||
bridge.publish_joint_state.assert_called_once()
|
||||
|
||||
|
||||
# ==================== E6: 同类型设备多个实例 — 重复关节名场景 ====================
|
||||
|
||||
|
||||
class TestDuplicateDeviceTypes:
|
||||
"""
|
||||
多个同类型设备(如 2 个 OT-2 移液器),关节名格式为 {device_id}_{joint_name}。
|
||||
设备 ID 不同(如 ot2_left, ot2_right),但底层关节名相同(如 pipette_j1)。
|
||||
"""
|
||||
|
||||
def test_same_type_different_id(self):
|
||||
"""同类型设备不同 ID"""
|
||||
sim = _make_sim({
|
||||
"ot2_left": "uuid-ot2-left",
|
||||
"ot2_right": "uuid-ot2-right",
|
||||
})
|
||||
msg = _make_joint_state_msg(
|
||||
["ot2_left_pipette_j1", "ot2_left_pipette_j2",
|
||||
"ot2_right_pipette_j1", "ot2_right_pipette_j2"],
|
||||
[0.1, 0.2, 0.3, 0.4],
|
||||
)
|
||||
result = sim.joint_state_callback(msg)
|
||||
assert result["ot2_left"]["joint_states"]["ot2_left_pipette_j1"] == 0.1
|
||||
assert result["ot2_left"]["joint_states"]["ot2_left_pipette_j2"] == 0.2
|
||||
assert result["ot2_right"]["joint_states"]["ot2_right_pipette_j1"] == 0.3
|
||||
assert result["ot2_right"]["joint_states"]["ot2_right_pipette_j2"] == 0.4
|
||||
|
||||
def test_same_type_with_resources_routed_correctly(self):
|
||||
"""同类型设备各自抓取资源,按 link 前缀正确路由"""
|
||||
sim = _make_sim({
|
||||
"ot2_left": "uuid-ot2-left",
|
||||
"ot2_right": "uuid-ot2-right",
|
||||
})
|
||||
sim.resource_pose_callback(_make_string_msg(json.dumps({
|
||||
"plate_A": "ot2_left_gripper",
|
||||
"plate_B": "ot2_right_gripper",
|
||||
})))
|
||||
msg = _make_joint_state_msg(
|
||||
["ot2_left_j1", "ot2_right_j1"],
|
||||
[0.5, 0.6],
|
||||
)
|
||||
result = sim.joint_state_callback(msg)
|
||||
assert result["ot2_left"]["resource_poses"] == {"plate_A": "ot2_left_gripper"}
|
||||
assert result["ot2_right"]["resource_poses"] == {"plate_B": "ot2_right_gripper"}
|
||||
|
||||
def test_numbered_devices_no_confusion(self):
|
||||
"""编号设备: robot1 不应匹配 robot10 的关节"""
|
||||
sim = _make_sim({
|
||||
"robot1": "uuid-r1",
|
||||
"robot10": "uuid-r10",
|
||||
})
|
||||
msg = _make_joint_state_msg(
|
||||
["robot1_j1", "robot10_j1"],
|
||||
[1.0, 10.0],
|
||||
)
|
||||
result = sim.joint_state_callback(msg)
|
||||
assert result["robot1"]["joint_states"]["robot1_j1"] == 1.0
|
||||
assert result["robot10"]["joint_states"]["robot10_j1"] == 10.0
|
||||
|
||||
def test_three_same_type_devices(self):
|
||||
"""三个同类型设备"""
|
||||
sim = _make_sim({
|
||||
"pump_a": "uuid-pa",
|
||||
"pump_b": "uuid-pb",
|
||||
"pump_c": "uuid-pc",
|
||||
})
|
||||
msg = _make_joint_state_msg(
|
||||
["pump_a_flow", "pump_b_flow", "pump_c_flow",
|
||||
"pump_a_pressure", "pump_b_pressure"],
|
||||
[1.0, 2.0, 3.0, 0.1, 0.2],
|
||||
)
|
||||
result = sim.joint_state_callback(msg)
|
||||
assert len(result["pump_a"]["joint_states"]) == 2
|
||||
assert len(result["pump_b"]["joint_states"]) == 2
|
||||
assert len(result["pump_c"]["joint_states"]) == 1
|
||||
|
||||
|
||||
# ==================== E7: 吞吐优化测试 ====================
|
||||
|
||||
|
||||
class TestThroughputOptimizations:
|
||||
"""测试死区过滤、抑频、增量 resource_poses 等优化行为"""
|
||||
|
||||
# --- 死区过滤 (Dead Band) ---
|
||||
|
||||
def test_dead_band_filters_tiny_change(self):
|
||||
"""关节变化小于死区阈值 → 被过滤"""
|
||||
sim = JointBridgeSimulator({"arm": "uuid-arm"}, dead_band=0.01, min_interval=0.0)
|
||||
msg1 = _make_joint_state_msg(["arm_j1"], [1.0])
|
||||
result1 = sim.joint_state_callback(msg1, now=0.0)
|
||||
assert "arm" in result1
|
||||
|
||||
# 微小变化 (0.001 < 0.01 死区)
|
||||
msg2 = _make_joint_state_msg(["arm_j1"], [1.001])
|
||||
result2 = sim.joint_state_callback(msg2, now=1.0)
|
||||
assert result2 == {}
|
||||
|
||||
def test_dead_band_passes_significant_change(self):
|
||||
"""关节变化大于死区阈值 → 通过"""
|
||||
sim = JointBridgeSimulator({"arm": "uuid-arm"}, dead_band=0.01, min_interval=0.0)
|
||||
msg1 = _make_joint_state_msg(["arm_j1"], [1.0])
|
||||
sim.joint_state_callback(msg1, now=0.0)
|
||||
|
||||
msg2 = _make_joint_state_msg(["arm_j1"], [1.05])
|
||||
result2 = sim.joint_state_callback(msg2, now=1.0)
|
||||
assert "arm" in result2
|
||||
assert result2["arm"]["joint_states"]["arm_j1"] == pytest.approx(1.05)
|
||||
|
||||
def test_dead_band_first_message_always_passes(self):
|
||||
"""首次消息总是通过(无历史值)"""
|
||||
sim = JointBridgeSimulator({"arm": "uuid-arm"}, dead_band=1000.0, min_interval=0.0)
|
||||
msg = _make_joint_state_msg(["arm_j1"], [0.001])
|
||||
result = sim.joint_state_callback(msg, now=0.0)
|
||||
assert "arm" in result
|
||||
|
||||
def test_dead_band_any_joint_change_triggers(self):
|
||||
"""多关节中只要有一个超过死区就全部发送"""
|
||||
sim = JointBridgeSimulator({"arm": "uuid-arm"}, dead_band=0.01, min_interval=0.0)
|
||||
msg1 = _make_joint_state_msg(["arm_j1", "arm_j2"], [1.0, 2.0])
|
||||
sim.joint_state_callback(msg1, now=0.0)
|
||||
|
||||
# j1 微变化,j2 大变化
|
||||
msg2 = _make_joint_state_msg(["arm_j1", "arm_j2"], [1.001, 2.5])
|
||||
result2 = sim.joint_state_callback(msg2, now=1.0)
|
||||
assert "arm" in result2
|
||||
# 两个关节的值都应包含在结果中
|
||||
assert result2["arm"]["joint_states"]["arm_j1"] == pytest.approx(1.001)
|
||||
assert result2["arm"]["joint_states"]["arm_j2"] == pytest.approx(2.5)
|
||||
|
||||
# --- 抑频 (Throttle) ---
|
||||
|
||||
def test_throttle_filters_rapid_messages(self):
|
||||
"""发送间隔内的消息被过滤"""
|
||||
sim = JointBridgeSimulator({"arm": "uuid-arm"}, dead_band=0.0, min_interval=0.1)
|
||||
msg1 = _make_joint_state_msg(["arm_j1"], [1.0])
|
||||
result1 = sim.joint_state_callback(msg1, now=0.0)
|
||||
assert "arm" in result1
|
||||
|
||||
# 0.05s < 0.1s 间隔
|
||||
msg2 = _make_joint_state_msg(["arm_j1"], [2.0])
|
||||
result2 = sim.joint_state_callback(msg2, now=0.05)
|
||||
assert result2 == {}
|
||||
|
||||
def test_throttle_passes_after_interval(self):
|
||||
"""超过发送间隔后消息通过"""
|
||||
sim = JointBridgeSimulator({"arm": "uuid-arm"}, dead_band=0.0, min_interval=0.1)
|
||||
msg1 = _make_joint_state_msg(["arm_j1"], [1.0])
|
||||
sim.joint_state_callback(msg1, now=0.0)
|
||||
|
||||
msg2 = _make_joint_state_msg(["arm_j1"], [2.0])
|
||||
result2 = sim.joint_state_callback(msg2, now=0.15)
|
||||
assert "arm" in result2
|
||||
|
||||
def test_throttle_bypassed_by_resource_change(self):
|
||||
"""resource_pose 变化时忽略抑频限制"""
|
||||
sim = JointBridgeSimulator({"arm": "uuid-arm"}, dead_band=0.0, min_interval=1.0)
|
||||
msg1 = _make_joint_state_msg(["arm_j1"], [1.0])
|
||||
sim.joint_state_callback(msg1, now=0.0)
|
||||
|
||||
# 资源变化 → 强制发送
|
||||
sim.resource_pose_callback(_make_string_msg(json.dumps({"plate": "arm_gripper"})))
|
||||
msg2 = _make_joint_state_msg(["arm_j1"], [1.0])
|
||||
result2 = sim.joint_state_callback(msg2, now=0.01) # 远小于 1.0 间隔
|
||||
assert "arm" in result2
|
||||
assert result2["arm"]["resource_poses"] == {"plate": "arm_gripper"}
|
||||
|
||||
# --- 增量 resource_poses ---
|
||||
|
||||
def test_resource_poses_only_sent_when_dirty(self):
|
||||
"""resource_poses 仅在 dirty 时附带,否则为空"""
|
||||
sim = _make_sim({"panda": "uuid-panda"})
|
||||
sim.resource_pose_callback(_make_string_msg(json.dumps({"plate": "panda_gripper"})))
|
||||
|
||||
# 第一次发送:dirty → 携带 resource_poses
|
||||
msg1 = _make_joint_state_msg(["panda_j1"], [0.5])
|
||||
result1 = sim.joint_state_callback(msg1)
|
||||
assert result1["panda"]["resource_poses"] == {"plate": "panda_gripper"}
|
||||
|
||||
# dirty 已清除
|
||||
assert sim._resource_poses_dirty is False
|
||||
|
||||
# 第二次发送:not dirty → resource_poses 为空
|
||||
msg2 = _make_joint_state_msg(["panda_j1"], [1.0])
|
||||
result2 = sim.joint_state_callback(msg2)
|
||||
assert result2["panda"]["resource_poses"] == {}
|
||||
|
||||
def test_resource_change_resets_dirty_after_send(self):
|
||||
"""dirty 在发送后被重置,再次 resource_pose 变化后重新标记"""
|
||||
sim = _make_sim({"panda": "uuid-panda"})
|
||||
sim.resource_pose_callback(_make_string_msg(json.dumps({"plate": "panda_deck"})))
|
||||
|
||||
msg = _make_joint_state_msg(["panda_j1"], [0.5])
|
||||
sim.joint_state_callback(msg)
|
||||
assert sim._resource_poses_dirty is False
|
||||
|
||||
# 再次资源变化
|
||||
sim.resource_pose_callback(_make_string_msg(json.dumps({"plate": "panda_gripper"})))
|
||||
assert sim._resource_poses_dirty is True
|
||||
|
||||
msg2 = _make_joint_state_msg(["panda_j1"], [1.0])
|
||||
result2 = sim.joint_state_callback(msg2)
|
||||
assert result2["panda"]["resource_poses"] == {"plate": "panda_gripper"}
|
||||
|
||||
# --- 组合场景 ---
|
||||
|
||||
def test_dead_band_bypassed_by_resource_dirty(self):
|
||||
"""关节无变化但 resource_pose 有变化 → 仍然发送"""
|
||||
sim = JointBridgeSimulator({"arm": "uuid-arm"}, dead_band=0.01, min_interval=0.0)
|
||||
msg1 = _make_joint_state_msg(["arm_j1"], [1.0])
|
||||
sim.joint_state_callback(msg1, now=0.0)
|
||||
|
||||
sim.resource_pose_callback(_make_string_msg(json.dumps({"plate": "arm_gripper"})))
|
||||
# 关节值完全不变
|
||||
msg2 = _make_joint_state_msg(["arm_j1"], [1.0])
|
||||
result2 = sim.joint_state_callback(msg2, now=1.0)
|
||||
assert "arm" in result2
|
||||
assert result2["arm"]["resource_poses"] == {"plate": "arm_gripper"}
|
||||
|
||||
def test_high_frequency_stream_only_significant_pass(self):
|
||||
"""模拟高频流: 只有显著变化的消息通过"""
|
||||
sim = JointBridgeSimulator({"arm": "uuid-arm"}, dead_band=0.01, min_interval=0.0)
|
||||
t = 0.0
|
||||
passed_count = 0
|
||||
|
||||
# 100 条消息,每条微小递增 0.001
|
||||
for i in range(100):
|
||||
t += 0.1
|
||||
val = 1.0 + i * 0.001
|
||||
msg = _make_joint_state_msg(["arm_j1"], [val])
|
||||
result = sim.joint_state_callback(msg, now=t)
|
||||
if result:
|
||||
passed_count += 1
|
||||
|
||||
# 首次总通过 + 每 10 条左右(累计 0.01 变化)通过一次
|
||||
assert passed_count < 20 # 远少于 100
|
||||
assert passed_count >= 5 # 但不应为 0
|
||||
|
||||
def test_throttle_and_dead_band_combined(self):
|
||||
"""同时受抑频和死区影响"""
|
||||
sim = JointBridgeSimulator({"arm": "uuid-arm"}, dead_band=0.01, min_interval=0.5)
|
||||
|
||||
# 首条通过
|
||||
msg1 = _make_joint_state_msg(["arm_j1"], [1.0])
|
||||
assert sim.joint_state_callback(msg1, now=0.0) != {}
|
||||
|
||||
# 时间不够 + 变化不够 → 过滤
|
||||
msg2 = _make_joint_state_msg(["arm_j1"], [1.001])
|
||||
assert sim.joint_state_callback(msg2, now=0.1) == {}
|
||||
|
||||
# 时间够但变化不够 → 过滤
|
||||
msg3 = _make_joint_state_msg(["arm_j1"], [1.002])
|
||||
assert sim.joint_state_callback(msg3, now=1.0) == {}
|
||||
|
||||
# 时间够且变化够 → 通过
|
||||
msg4 = _make_joint_state_msg(["arm_j1"], [1.05])
|
||||
assert sim.joint_state_callback(msg4, now=1.5) != {}
|
||||
@@ -50,6 +50,17 @@ class BaseCommunicationClient(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def publish_joint_state(self, node_uuid: str, joint_states: dict, resource_poses: dict = None) -> None:
|
||||
"""
|
||||
发布高频关节状态数据(push_joint_state action,不写 DB)
|
||||
|
||||
Args:
|
||||
node_uuid: 设备节点的云端 UUID
|
||||
joint_states: 关节名 → 角度/位置 的映射
|
||||
resource_poses: 物料附着映射(可选)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def publish_job_status(
|
||||
self, feedback_data: dict, job_id: str, status: str, return_info: Optional[dict] = None
|
||||
|
||||
210
unilabos/app/model_upload.py
Normal file
210
unilabos/app/model_upload.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""模型文件上传/下载管理。
|
||||
|
||||
提供 Edge 端本地模型文件与 OSS 之间的双向同步:
|
||||
- upload_device_model: 本地模型 → OSS(Edge 首次接入时)
|
||||
- download_model_from_oss: OSS → 本地(新 Edge 加入已有 Lab 时)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from unilabos.utils.log import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from unilabos.app.web.client import HTTPClient
|
||||
|
||||
# 设备 mesh 根目录
|
||||
_MESH_BASE_DIR = Path(__file__).parent.parent / "device_mesh"
|
||||
|
||||
# 支持的模型文件后缀
|
||||
_MODEL_EXTENSIONS = frozenset({
|
||||
".xacro", ".urdf", ".stl", ".dae", ".obj",
|
||||
".gltf", ".glb", ".fbx", ".yaml", ".yml",
|
||||
})
|
||||
|
||||
# 需要 XOR 加密/解密的 mesh 文件后缀(反爬保护 — 方案 C)
|
||||
_MESH_ENCRYPT_EXTENSIONS = frozenset({
|
||||
".stl", ".dae", ".obj", ".fbx", ".gltf", ".glb",
|
||||
})
|
||||
|
||||
# XOR 密钥 — 从环境变量读取,与前端 mesh-decrypt.ts 一致
|
||||
_XOR_KEY = os.environ.get("UNILAB_MESH_XOR_KEY", "unilab3d-model-protection-key-v1").encode()
|
||||
|
||||
|
||||
def _xor_transform(data: bytes, key: bytes = _XOR_KEY) -> bytes:
|
||||
"""XOR 加密/解密(对称操作)。"""
|
||||
key_len = len(key)
|
||||
return bytes(b ^ key[i % key_len] for i, b in enumerate(data))
|
||||
|
||||
|
||||
def upload_device_model(
|
||||
http_client: "HTTPClient",
|
||||
template_uuid: str,
|
||||
mesh_name: str,
|
||||
model_type: str,
|
||||
version: str = "1.0.0",
|
||||
) -> Optional[str]:
|
||||
"""上传本地模型文件到 OSS,返回入口文件的 OSS URL。
|
||||
|
||||
Args:
|
||||
http_client: HTTPClient 实例
|
||||
template_uuid: 设备模板 UUID
|
||||
mesh_name: mesh 目录名(如 "arm_slider")
|
||||
model_type: "device" 或 "resource"
|
||||
version: 模型版本
|
||||
|
||||
Returns:
|
||||
入口文件 OSS URL,上传失败返回 None
|
||||
"""
|
||||
if model_type == "device":
|
||||
model_dir = _MESH_BASE_DIR / "devices" / mesh_name
|
||||
else:
|
||||
model_dir = _MESH_BASE_DIR / "resources" / mesh_name
|
||||
|
||||
if not model_dir.exists():
|
||||
logger.warning(f"[模型上传] 本地目录不存在: {model_dir}")
|
||||
return None
|
||||
|
||||
# 收集所有需要上传的文件
|
||||
files = []
|
||||
for f in model_dir.rglob("*"):
|
||||
if f.is_file() and f.suffix.lower() in _MODEL_EXTENSIONS:
|
||||
files.append({
|
||||
"name": str(f.relative_to(model_dir)),
|
||||
"size_kb": f.stat().st_size // 1024,
|
||||
})
|
||||
|
||||
if not files:
|
||||
logger.warning(f"[模型上传] 目录中无可上传的模型文件: {model_dir}")
|
||||
return None
|
||||
|
||||
try:
|
||||
# 1. 获取预签名上传 URL
|
||||
upload_urls_resp = http_client.get_model_upload_urls(
|
||||
template_uuid=template_uuid,
|
||||
files=[{"name": f["name"], "version": version} for f in files],
|
||||
)
|
||||
if not upload_urls_resp:
|
||||
return None
|
||||
|
||||
url_items = upload_urls_resp.get("files", [])
|
||||
|
||||
# 2. 逐个上传文件
|
||||
for file_info, url_info in zip(files, url_items):
|
||||
local_path = model_dir / file_info["name"]
|
||||
upload_url = url_info.get("upload_url", "")
|
||||
if not upload_url:
|
||||
continue
|
||||
_put_upload(local_path, upload_url)
|
||||
|
||||
# 3. 确认发布
|
||||
entry_file = "macro_device.xacro" if model_type == "device" else "modal.xacro"
|
||||
# 检查入口文件是否存在,使用实际存在的文件名
|
||||
for f in files:
|
||||
if f["name"].endswith(".xacro"):
|
||||
entry_file = f["name"]
|
||||
break
|
||||
|
||||
publish_resp = http_client.publish_model(
|
||||
template_uuid=template_uuid,
|
||||
version=version,
|
||||
entry_file=entry_file,
|
||||
)
|
||||
return publish_resp.get("path") if publish_resp else None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[模型上传] 上传失败 ({mesh_name}): {e}")
|
||||
return None
|
||||
|
||||
|
||||
def download_model_from_oss(
|
||||
model_config: dict,
|
||||
mesh_base_dir: Optional[Path] = None,
|
||||
) -> bool:
|
||||
"""检查本地模型文件是否存在,不存在则从 OSS 下载。
|
||||
|
||||
Args:
|
||||
model_config: 节点的 model 配置字典
|
||||
mesh_base_dir: mesh 根目录,默认使用 device_mesh/
|
||||
|
||||
Returns:
|
||||
True 表示本地文件就绪,False 表示下载失败或无需下载
|
||||
"""
|
||||
if mesh_base_dir is None:
|
||||
mesh_base_dir = _MESH_BASE_DIR
|
||||
|
||||
mesh_name = model_config.get("mesh", "")
|
||||
model_type = model_config.get("type", "")
|
||||
oss_path = model_config.get("path", "")
|
||||
|
||||
if not mesh_name or not oss_path or not oss_path.startswith("https://"):
|
||||
return False
|
||||
|
||||
# 确定本地目标目录
|
||||
if model_type == "device":
|
||||
local_dir = mesh_base_dir / "devices" / mesh_name
|
||||
elif model_type == "resource":
|
||||
resource_name = mesh_name.split("/")[0]
|
||||
local_dir = mesh_base_dir / "resources" / resource_name
|
||||
else:
|
||||
return False
|
||||
|
||||
# 已有本地文件 → 跳过
|
||||
if local_dir.exists() and any(local_dir.iterdir()):
|
||||
return True
|
||||
|
||||
# 从 OSS 下载
|
||||
local_dir.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
# 下载入口文件(OSS URL 通常直接可访问)
|
||||
entry_name = oss_path.rsplit("/", 1)[-1]
|
||||
_download_file(oss_path, local_dir / entry_name)
|
||||
|
||||
# 如果有 children_mesh,也下载
|
||||
children_mesh = model_config.get("children_mesh")
|
||||
if isinstance(children_mesh, dict) and children_mesh.get("path"):
|
||||
cm_path = children_mesh["path"]
|
||||
if cm_path.startswith("https://"):
|
||||
cm_name = cm_path.rsplit("/", 1)[-1]
|
||||
meshes_dir = local_dir / "meshes"
|
||||
meshes_dir.mkdir(parents=True, exist_ok=True)
|
||||
_download_file(cm_path, meshes_dir / cm_name)
|
||||
|
||||
logger.info(f"[模型下载] 成功下载模型到本地: {mesh_name} → {local_dir}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[模型下载] 下载失败 ({mesh_name}): {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _put_upload(local_path: Path, upload_url: str) -> None:
|
||||
"""通过预签名 URL 上传文件到 OSS。对 mesh 文件自动 XOR 加密。"""
|
||||
with open(local_path, "rb") as f:
|
||||
data = f.read()
|
||||
# 对 mesh 文件 XOR 加密后上传(反爬保护 — 方案 C)
|
||||
if local_path.suffix.lower() in _MESH_ENCRYPT_EXTENSIONS:
|
||||
data = _xor_transform(data)
|
||||
logger.debug(f"[模型上传] XOR 加密: {local_path.name}")
|
||||
resp = requests.put(upload_url, data=data, timeout=120)
|
||||
resp.raise_for_status()
|
||||
logger.debug(f"[模型上传] 已上传: {local_path.name}")
|
||||
|
||||
|
||||
def _download_file(url: str, local_path: Path) -> None:
|
||||
"""下载单个文件到本地路径。对 mesh 文件自动 XOR 解密。"""
|
||||
local_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
resp = requests.get(url, timeout=60)
|
||||
resp.raise_for_status()
|
||||
data = resp.content
|
||||
# 从 OSS 下载的 mesh 文件是加密的,需要 XOR 解密后再存本地
|
||||
if local_path.suffix.lower() in _MESH_ENCRYPT_EXTENSIONS:
|
||||
data = _xor_transform(data)
|
||||
logger.debug(f"[模型下载] XOR 解密: {local_path.name}")
|
||||
local_path.write_bytes(data)
|
||||
logger.debug(f"[模型下载] 已下载: {local_path}")
|
||||
@@ -5,6 +5,48 @@ from unilabos.utils.log import logger
|
||||
from unilabos.utils.tools import normalize_json as _normalize_device
|
||||
|
||||
|
||||
def normalize_model_for_upload(model_dict: dict) -> dict:
|
||||
"""将 Registry YAML 的 model 字段映射为后端 DeviceModel 结构化格式。
|
||||
|
||||
保留所有原始字段,额外做以下标准化:
|
||||
1. 自动推断 format(如果 YAML 未指定)
|
||||
2. 将 children_mesh 扁平字段映射为结构化 children_mesh 对象
|
||||
"""
|
||||
if not model_dict:
|
||||
return model_dict
|
||||
|
||||
result = {**model_dict}
|
||||
|
||||
# 自动推断 format
|
||||
if "format" not in result and result.get("path"):
|
||||
path = result["path"]
|
||||
if path.endswith(".xacro"):
|
||||
result["format"] = "xacro"
|
||||
elif path.endswith(".urdf"):
|
||||
result["format"] = "urdf"
|
||||
elif path.endswith(".stl"):
|
||||
result["format"] = "stl"
|
||||
elif path.endswith((".gltf", ".glb")):
|
||||
result["format"] = "gltf"
|
||||
|
||||
# 将 children_mesh 扁平字段 → 结构化 children_mesh 对象
|
||||
if "children_mesh" in result and isinstance(result["children_mesh"], str):
|
||||
cm_path = result.pop("children_mesh")
|
||||
cm_tf = result.pop("children_mesh_tf", None)
|
||||
cm_oss = result.pop("children_mesh_path", None)
|
||||
result["children_mesh"] = {
|
||||
"path": cm_oss or cm_path,
|
||||
"format": "stl" if cm_path.endswith(".stl") else "gltf",
|
||||
"default_visible": True,
|
||||
}
|
||||
if cm_tf and len(cm_tf) >= 3:
|
||||
result["children_mesh"]["local_offset"] = cm_tf[:3]
|
||||
if cm_tf and len(cm_tf) >= 6:
|
||||
result["children_mesh"]["local_rotation"] = cm_tf[3:6]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def register_devices_and_resources(lab_registry, gather_only=False) -> Optional[Tuple[Dict[str, Any], Dict[str, Any]]]:
|
||||
"""
|
||||
注册设备和资源到服务器(仅支持HTTP)
|
||||
@@ -16,11 +58,18 @@ def register_devices_and_resources(lab_registry, gather_only=False) -> Optional[
|
||||
|
||||
devices_to_register = {}
|
||||
for device_info in lab_registry.obtain_registry_device_info():
|
||||
devices_to_register[device_info["id"]] = _normalize_device(device_info)
|
||||
normalized = _normalize_device(device_info)
|
||||
# 标准化 model 字段
|
||||
if normalized.get("model"):
|
||||
normalized["model"] = normalize_model_for_upload(normalized["model"])
|
||||
devices_to_register[device_info["id"]] = normalized
|
||||
logger.trace(f"[UniLab Register] 收集设备: {device_info['id']}")
|
||||
|
||||
resources_to_register = {}
|
||||
for resource_info in lab_registry.obtain_registry_resource_info():
|
||||
# 标准化 model 字段
|
||||
if resource_info.get("model"):
|
||||
resource_info["model"] = normalize_model_for_upload(resource_info["model"])
|
||||
resources_to_register[resource_info["id"]] = resource_info
|
||||
logger.trace(f"[UniLab Register] 收集资源: {resource_info['id']}")
|
||||
|
||||
|
||||
@@ -468,6 +468,63 @@ class HTTPClient:
|
||||
logger.error(f"发布工作流失败: {response.status_code}, {response.text}")
|
||||
return {"code": response.status_code, "message": response.text}
|
||||
|
||||
# ──────────────────── 模型资产管理 ────────────────────
|
||||
|
||||
def get_model_upload_urls(
|
||||
self, template_uuid: str, files: list[dict],
|
||||
) -> dict | None:
|
||||
"""获取模型文件预签名上传 URL。
|
||||
|
||||
Args:
|
||||
template_uuid: 设备模板 UUID
|
||||
files: 文件列表 [{"name": "...", "version": "1.0.0"}]
|
||||
|
||||
Returns:
|
||||
{"files": [{"name": "...", "upload_url": "...", "path": "..."}]}
|
||||
"""
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{self.remote_addr}/lab/square/template/{template_uuid}/model/upload-urls",
|
||||
json={"files": files},
|
||||
headers={"Authorization": f"Lab {self.auth}"},
|
||||
timeout=30,
|
||||
)
|
||||
if response.status_code == 200:
|
||||
data = response.json().get("data")
|
||||
return data
|
||||
logger.error(f"获取模型上传 URL 失败: {response.status_code}, {response.text}")
|
||||
except Exception as e:
|
||||
logger.error(f"获取模型上传 URL 异常: {e}")
|
||||
return None
|
||||
|
||||
def publish_model(
|
||||
self, template_uuid: str, version: str, entry_file: str,
|
||||
) -> dict | None:
|
||||
"""确认模型上传完成,发布新版本。
|
||||
|
||||
Args:
|
||||
template_uuid: 设备模板 UUID
|
||||
version: 模型版本
|
||||
entry_file: 入口文件名
|
||||
|
||||
Returns:
|
||||
{"path": "...", "oss_dir": "...", "version": "..."}
|
||||
"""
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{self.remote_addr}/lab/square/template/{template_uuid}/model/publish",
|
||||
json={"version": version, "entry_file": entry_file},
|
||||
headers={"Authorization": f"Lab {self.auth}"},
|
||||
timeout=30,
|
||||
)
|
||||
if response.status_code == 200:
|
||||
data = response.json().get("data")
|
||||
return data
|
||||
logger.error(f"发布模型失败: {response.status_code}, {response.text}")
|
||||
except Exception as e:
|
||||
logger.error(f"发布模型异常: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# 创建默认客户端实例
|
||||
http_client = HTTPClient()
|
||||
|
||||
@@ -1434,6 +1434,21 @@ class WebSocketClient(BaseCommunicationClient):
|
||||
self.message_processor.send_message(message)
|
||||
# logger.trace(f"[WebSocketClient] Device status published: {device_id}.{property_name}")
|
||||
|
||||
def publish_joint_state(self, node_uuid: str, joint_states: dict, resource_poses: dict = None) -> None:
|
||||
"""发布高频关节状态(push_joint_state,不写 DB)"""
|
||||
if self.is_disabled or not self.is_connected():
|
||||
return
|
||||
|
||||
message = {
|
||||
"action": "push_joint_state",
|
||||
"data": {
|
||||
"node_uuid": node_uuid,
|
||||
"joint_states": joint_states or {},
|
||||
"resource_poses": resource_poses or {},
|
||||
},
|
||||
}
|
||||
self.message_processor.send_message(message)
|
||||
|
||||
def publish_job_status(
|
||||
self, feedback_data: dict, item: QueueItem, status: str, return_info: Optional[dict] = None
|
||||
) -> None:
|
||||
|
||||
@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Optional, Dict, Any, List, ClassVar, Set, Unio
|
||||
|
||||
from action_msgs.msg import GoalStatus
|
||||
from geometry_msgs.msg import Point
|
||||
from sensor_msgs.msg import JointState as JointStateMsg
|
||||
from rclpy.action import ActionClient, get_action_server_names_and_types_by_node
|
||||
from rclpy.service import Service
|
||||
from typing_extensions import TypedDict
|
||||
@@ -348,6 +349,10 @@ class HostNode(BaseROS2DeviceNode):
|
||||
else:
|
||||
self.lab_logger().warning(f"[Host Node] Device {device_id} already existed, skipping.")
|
||||
self.update_device_status_subscriptions()
|
||||
|
||||
# 订阅 joint_state_repub topic,桥接关节数据到云端
|
||||
self._init_joint_state_bridge()
|
||||
|
||||
# TODO: 需要验证 初始化所有控制器节点
|
||||
if controllers_config:
|
||||
update_rate = controllers_config["controller_manager"]["ros__parameters"]["update_rate"]
|
||||
@@ -782,6 +787,179 @@ class HostNode(BaseROS2DeviceNode):
|
||||
else:
|
||||
self.lab_logger().trace(f"Status updated: {device_id}.{property_name} = {msg.data}")
|
||||
|
||||
"""关节数据 & 资源跟随桥接"""
|
||||
|
||||
# 吞吐优化参数
|
||||
_JOINT_DEAD_BAND: float = 1e-4 # 关节角度变化小于此值视为无变化
|
||||
_JOINT_MIN_INTERVAL: float = 0.05 # 最小发送间隔 (秒),限制到 ~20Hz
|
||||
|
||||
def _init_joint_state_bridge(self):
|
||||
"""
|
||||
订阅 /joint_states (sensor_msgs/JointState) 和 resource_pose (String),
|
||||
构建 device_id → uuid 映射,并维护 resource_poses 状态。
|
||||
|
||||
吞吐优化:
|
||||
- 死区过滤 (dead band): 关节角度变化 < 阈值时不发送
|
||||
- 抑频 (throttle): 限制最大发送频率,避免 ROS2 1kHz 打满 WS
|
||||
- 增量 resource_poses: 仅在 resource_pose 实际变化时才附带发送
|
||||
"""
|
||||
# 构建 device_id → cloud_uuid 映射(从 devices_config 中获取)
|
||||
self._device_uuid_map: Dict[str, str] = {}
|
||||
for tree in self.devices_config.trees:
|
||||
node = tree.root_node
|
||||
if node.res_content.type == "device" and node.res_content.uuid:
|
||||
self._device_uuid_map[node.res_content.id] = node.res_content.uuid
|
||||
|
||||
# 按 device_id 长度降序排列,最长前缀优先匹配(避免 arm 抢先匹配 arm_left_j1)
|
||||
self._device_ids_sorted = sorted(self._device_uuid_map.keys(), key=len, reverse=True)
|
||||
|
||||
# 资源挂载状态:{resource_id: parent_link_name}
|
||||
self._resource_poses: Dict[str, str] = {}
|
||||
# resource_pose 变化标志,仅在真正变化时随关节数据发送
|
||||
self._resource_poses_dirty: bool = False
|
||||
|
||||
# 吞吐优化状态
|
||||
self._last_joint_values: Dict[str, float] = {} # 上次发送的关节值(全局)
|
||||
self._last_send_time: float = -float("inf") # 上次发送时间戳(初始为-inf确保首条通过)
|
||||
self._last_sent_resource_poses: Dict[str, str] = {} # 上次发送的 resource_poses 快照
|
||||
|
||||
if not self._device_uuid_map:
|
||||
self.lab_logger().debug("[Host Node] 无设备 UUID 映射,跳过关节桥接")
|
||||
return
|
||||
|
||||
# 直接订阅 /joint_states(sensor_msgs/JointState),无需经过 JointRepublisher
|
||||
self.create_subscription(
|
||||
JointStateMsg,
|
||||
"/joint_states",
|
||||
self._joint_state_callback,
|
||||
10,
|
||||
callback_group=self.callback_group,
|
||||
)
|
||||
|
||||
# 订阅 resource_pose(资源挂载变化,由 ResourceMeshManager 发布)
|
||||
from std_msgs.msg import String as StdString
|
||||
self.create_subscription(
|
||||
StdString,
|
||||
"resource_pose",
|
||||
self._resource_pose_callback,
|
||||
10,
|
||||
callback_group=self.callback_group,
|
||||
)
|
||||
|
||||
self.lab_logger().info(
|
||||
f"[Host Node] 已订阅 /joint_states 和 resource_pose,设备映射: {list(self._device_uuid_map.keys())}"
|
||||
)
|
||||
|
||||
def _resource_pose_callback(self, msg):
|
||||
"""
|
||||
接收 ResourceMeshManager 发布的资源挂载变更。
|
||||
|
||||
msg.data 格式: JSON dict,如 {"tip_rack_A1": "gripper_link", "plate_1": "deck_link"}
|
||||
空 dict {} 表示无变化(心跳包)。
|
||||
"""
|
||||
try:
|
||||
data = json.loads(msg.data)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
return
|
||||
if not isinstance(data, dict) or not data:
|
||||
return
|
||||
# 检测实际变化
|
||||
has_change = False
|
||||
for k, v in data.items():
|
||||
if self._resource_poses.get(k) != v:
|
||||
has_change = True
|
||||
break
|
||||
if has_change:
|
||||
self._resource_poses.update(data)
|
||||
self._resource_poses_dirty = True
|
||||
|
||||
def _joint_state_callback(self, msg: JointStateMsg):
|
||||
"""
|
||||
直接接收 /joint_states (sensor_msgs/JointState),按设备分组后通过 bridge 发送到云端。
|
||||
|
||||
吞吐优化:
|
||||
1. 抑频: 距上次发送 < _JOINT_MIN_INTERVAL 则跳过(除非有 resource_pose 变化)
|
||||
2. 死区: 所有关节角度变化 < _JOINT_DEAD_BAND 则跳过(除非有 resource_pose 变化)
|
||||
3. 增量 resource_poses: 仅在 dirty 时附带,否则发空 dict
|
||||
"""
|
||||
names = list(msg.name)
|
||||
positions = list(msg.position)
|
||||
if not names or len(names) != len(positions):
|
||||
return
|
||||
|
||||
now = time.time()
|
||||
resource_dirty = self._resource_poses_dirty
|
||||
|
||||
# 抑频检查:resource_pose 变化时强制发送
|
||||
if not resource_dirty and (now - self._last_send_time) < self._JOINT_MIN_INTERVAL:
|
||||
return
|
||||
|
||||
# 死区过滤:检测是否有关节值实质变化
|
||||
has_significant_change = False
|
||||
for name, pos in zip(names, positions):
|
||||
last_val = self._last_joint_values.get(name)
|
||||
if last_val is None or abs(float(pos) - last_val) >= self._JOINT_DEAD_BAND:
|
||||
has_significant_change = True
|
||||
break
|
||||
|
||||
# 无关节变化且无资源变化 → 跳过
|
||||
if not has_significant_change and not resource_dirty:
|
||||
return
|
||||
|
||||
# 更新上次发送的关节值
|
||||
for name, pos in zip(names, positions):
|
||||
self._last_joint_values[name] = float(pos)
|
||||
self._last_send_time = now
|
||||
|
||||
# 按设备 ID 分组关节数据(最长前缀优先匹配)
|
||||
device_joints: Dict[str, Dict[str, float]] = {}
|
||||
for name, pos in zip(names, positions):
|
||||
matched_device = None
|
||||
for device_id in self._device_ids_sorted:
|
||||
if name.startswith(device_id + "_"):
|
||||
matched_device = device_id
|
||||
break
|
||||
|
||||
if matched_device:
|
||||
if matched_device not in device_joints:
|
||||
device_joints[matched_device] = {}
|
||||
device_joints[matched_device][name] = float(pos)
|
||||
elif len(self._device_uuid_map) == 1:
|
||||
fallback_id = self._device_ids_sorted[0]
|
||||
if fallback_id not in device_joints:
|
||||
device_joints[fallback_id] = {}
|
||||
device_joints[fallback_id][name] = float(pos)
|
||||
|
||||
# 构建设备级 resource_poses(仅在 dirty 时附带实际数据)
|
||||
device_resource_poses: Dict[str, Dict[str, str]] = {}
|
||||
if resource_dirty:
|
||||
for resource_id, link_name in self._resource_poses.items():
|
||||
matched_device = None
|
||||
for device_id in self._device_ids_sorted:
|
||||
if link_name.startswith(device_id + "_"):
|
||||
matched_device = device_id
|
||||
break
|
||||
if matched_device:
|
||||
if matched_device not in device_resource_poses:
|
||||
device_resource_poses[matched_device] = {}
|
||||
device_resource_poses[matched_device][resource_id] = link_name
|
||||
elif len(self._device_uuid_map) == 1:
|
||||
fallback_id = self._device_ids_sorted[0]
|
||||
if fallback_id not in device_resource_poses:
|
||||
device_resource_poses[fallback_id] = {}
|
||||
device_resource_poses[fallback_id][resource_id] = link_name
|
||||
self._resource_poses_dirty = False
|
||||
|
||||
# 通过 bridge 发送 push_joint_state(含 resource_poses)
|
||||
for device_id, joint_states in device_joints.items():
|
||||
node_uuid = self._device_uuid_map.get(device_id)
|
||||
if not node_uuid:
|
||||
continue
|
||||
resource_poses = device_resource_poses.get(device_id, {})
|
||||
for bridge in self.bridges:
|
||||
if hasattr(bridge, "publish_joint_state"):
|
||||
bridge.publish_joint_state(node_uuid, joint_states, resource_poses)
|
||||
|
||||
def send_goal(
|
||||
self,
|
||||
item: "QueueItem",
|
||||
|
||||
@@ -41,7 +41,7 @@ class JointRepublisher(BaseROS2DeviceNode):
|
||||
json_dict["velocity"] = list(msg.velocity)
|
||||
json_dict["effort"] = list(msg.effort)
|
||||
|
||||
self.msg.data = str(json_dict)
|
||||
self.msg.data = json.dumps(json_dict)
|
||||
self.joint_repub.publish(self.msg)
|
||||
# print('-'*20)
|
||||
# print(self.msg.data)
|
||||
|
||||
Reference in New Issue
Block a user