@woolford180 Here's what's supposed to happen.
In modeltrain.py:
```python
if msg_type == 'SetDevice':
    print("Setting device...\n", file=sys.stderr)
    device_id=tc.decode_strings(msg_data)[0]
    device = torch.device(device_id)
    pin_memory = True if 'cuda' in device_id else False
if msg_type == 'SetTorchScriptModel' and modules_valid:
    print("Setting torchscript model...\n", file=sys.stderr)
    buffer=io.BytesIO(msg_data)
    model = torch.jit.load(buffer, map_location=device)
```
First, TorchStudio sends a 'SetDevice' message, which triggers the first block ("Setting device..."):
1/ decode the device ID string sent by TorchStudio (usually something like "cpu" or "cuda:0")
2/ let PyTorch convert that device ID into a PyTorch device
3/ if the device is a CUDA device, set the pin_memory flag to True
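The three steps above can be sketched without TorchStudio or PyTorch as a small pure-Python stand-in. `parse_device_id` and `pin_memory_for` are hypothetical helpers for illustration only: they roughly mirror how a string like "cuda:0" splits into the `type` and `index` attributes of a real `torch.device`, but they do not reproduce PyTorch's full validation.

```python
from typing import Optional, Tuple


def parse_device_id(device_id: str) -> Tuple[str, Optional[int]]:
    """Hypothetical stand-in for step 2: split "cuda:0" into ("cuda", 0).

    A real torch.device exposes the same two pieces as .type and .index;
    this sketch only handles the "type" and "type:index" forms.
    """
    dev_type, sep, index = device_id.strip().partition(":")
    if not dev_type:
        raise ValueError(f"empty device id: {device_id!r}")
    return dev_type, (int(index) if sep else None)


def pin_memory_for(device_id: str) -> bool:
    """Step 3: pin host memory only when the target is a CUDA device."""
    dev_type, _ = parse_device_id(device_id)
    return dev_type == "cuda"


if __name__ == "__main__":
    for device_id in ("cpu", "cuda:0", "cuda"):
        print(device_id, parse_device_id(device_id), pin_memory_for(device_id))
```

Comparing the parsed type is slightly stricter than the substring test `'cuda' in device_id`, though both agree for the usual ids such as "cpu" and "cuda:0".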
Then TorchStudio sends a 'SetTorchScriptModel' message, which triggers the second block ("Setting torchscript model..."). But the script never seems to reach that point, otherwise you would see that message in the log.
So I suppose the script freezes at `device = torch.device(device_id)`.
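One way to test that suspicion outside TorchStudio is to run the call in a background thread and wait for it with a timeout. `returns_within` is a hypothetical helper, not part of TorchStudio or PyTorch; a `False` result only says the call had not returned after `timeout` seconds.

```python
import threading
from typing import Any, Callable


def returns_within(fn: Callable[..., Any], *args: Any, timeout: float = 10.0) -> bool:
    """Call fn(*args) in a daemon thread and report whether it returned in time.

    False means the call was still running after `timeout` seconds, which is
    what a freeze at that line would look like. The thread is not killed;
    because it is a daemon, it won't keep the interpreter alive at exit.
    Raising an exception also counts as returning.
    """
    done = threading.Event()

    def target() -> None:
        try:
            fn(*args)
        finally:
            done.set()  # signal completion whether fn returned or raised

    threading.Thread(target=target, daemon=True).start()
    return done.wait(timeout)
```

For example, in the same Python environment TorchStudio uses, `import torch` followed by `returns_within(torch.device, "cuda:0", timeout=30)` would show whether that single call hangs (the import itself is not covered by the timeout).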
Could you edit modeltrain.py to add a few extra debug messages, like this:
```python
if msg_type == 'SetDevice':
    print("Setting device...\n", file=sys.stderr)
    device_id=tc.decode_strings(msg_data)[0]
    print("device id:", device_id, "\n", file=sys.stderr)
    device = torch.device(device_id)
    print("device set\n", file=sys.stderr)
    pin_memory = True if 'cuda' in device_id else False
    print("pin_memory set\n", file=sys.stderr)
```
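If you'd rather not repeat the `print` calls, the same debug messages can go through a small helper that also timestamps them. `checkpoint` is a hypothetical helper, not something TorchStudio provides; it writes to stderr like the prints above and flushes immediately, so the last marker in the log shows how far the script got.

```python
import sys
import time

# Reference point for elapsed times; set when this module is imported.
_START = time.monotonic()


def checkpoint(label: str, *details: object) -> str:
    """Write a timestamped progress marker to stderr and return it.

    flush=True pushes the line out immediately, so the marker reaches the
    log even if the script hangs right afterwards.
    """
    elapsed = time.monotonic() - _START
    line = " ".join([f"[{elapsed:8.3f}s] {label}", *map(str, details)])
    print(line, file=sys.stderr, flush=True)
    return line
```

Inside the 'SetDevice' block, `checkpoint("device id:", device_id)` before `torch.device(device_id)` and `checkpoint("device set")` after it would give the same information as the prints, plus how long each step took.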
Then run it again, quit TorchStudio if the problem still happens, and send me the log files located in ~/TorchStudio/logs (on macOS or Linux) or %USERPROFILE%/TorchStudio/logs (on Windows).
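To find those files quickly, a short script can print the log folder and list its contents, newest first. The two function names are hypothetical; only the folder locations come from this thread. `Path.home()` resolves to `$HOME` on macOS/Linux and to `%USERPROFILE%` on Windows, so one expression covers both cases.

```python
from pathlib import Path
from typing import List


def torchstudio_log_dir() -> Path:
    """~/TorchStudio/logs on macOS/Linux, %USERPROFILE%/TorchStudio/logs on Windows."""
    return Path.home() / "TorchStudio" / "logs"


def list_log_files() -> List[Path]:
    """Return the files in the log folder, newest first (empty if it doesn't exist)."""
    log_dir = torchstudio_log_dir()
    if not log_dir.is_dir():
        return []
    files = [p for p in log_dir.iterdir() if p.is_file()]
    return sorted(files, key=lambda p: p.stat().st_mtime, reverse=True)


if __name__ == "__main__":
    print("Log folder:", torchstudio_log_dir())
    for path in list_log_files():
        print("  ", path.name)
```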
If needed, you can find my email on the About page, under Contact: https://torchstudio.ai/about/
Thanks!