Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added laser_scan_matcher/models/model_conv_16.pth
Binary file not shown.
Binary file added laser_scan_matcher/models/model_match_16.pth
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
116 changes: 0 additions & 116 deletions laser_scan_matcher/scripts/config.py

This file was deleted.

41 changes: 5 additions & 36 deletions laser_scan_matcher/scripts/match_laser_scans.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,49 +4,18 @@
from sensor_msgs.msg import LaserScan
from laser_scan_matcher.srv import MatchLaserScans, MatchLaserScansResponse
import sys
import os
from os import path
import torch

# TODO fix this hack
sys.path.append(path.dirname(path.dirname(path.dirname( path.dirname( path.abspath(__file__) ) ) )))
sys.path.append(path.dirname(__file__))

from model import FullNet, EmbeddingNet, LCCNet, DistanceNet, StructuredEmbeddingNet, ScanMatchNet, ScanConvNet, ScanTransformNet, ScanSingleConvNet, ScanUncertaintyNet

def create_laser_networks(model_dir, model_epoch, multi_gpu=True):
scan_conv = ScanConvNet()
if model_dir:
scan_conv.load_state_dict(torch.load(os.path.join(model_dir, 'model_conv_' + model_epoch + '.pth')))

scan_transform = ScanTransformNet()
if model_dir:
transform_path = os.path.join(model_dir, 'model_transform_' + model_epoch + '.pth')
if os.path.exists(transform_path):
scan_transform.load_state_dict(torch.load(transform_path))
else:
print("Warning: no `transform` network found for provided model_dir and epoch")

scan_match = ScanMatchNet()
if model_dir:
scan_match.load_state_dict(torch.load(os.path.join(model_dir, 'model_match_' + model_epoch + '.pth')))

if multi_gpu and torch.cuda.device_count() > 1:
print("Let's use", torch.cuda.device_count(), "GPUs!")
scan_conv = torch.nn.DataParallel(scan_conv)
scan_match = torch.nn.DataParallel(scan_match)
scan_transform = torch.nn.DataParallel(scan_transform)

scan_conv.cuda()
scan_match.cuda()
scan_transform.cuda()
return scan_conv, scan_match, scan_transform
sys.path.append(path.join(path.dirname(path.dirname(path.dirname(path.abspath(__file__)))), 'learning_loop_closure'))
from helpers import create_laser_networks

def create_scan_match_helper(scan_conv, scan_match):
def match_scans(req):

scan_np = np.array(req.scan.ranges).astype(np.float32)
alt_scan_np = np.array(req.alt_scan.ranges).astype(np.float32)
scan_np = np.minimum(np.array(req.scan.ranges).astype(np.float32), 30)
alt_scan_np = np.minimum(np.array(req.alt_scan.ranges).astype(np.float32), 30)

scan = torch.tensor(scan_np).cuda()
alt_scan = torch.tensor(alt_scan_np).cuda()
Expand Down Expand Up @@ -75,4 +44,4 @@ def service():
try:
service()
except rospy.ROSInterruptException:
pass
pass
Loading