Skip to content

Commit

Permalink
Minor improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
Tobias-Fischer committed Oct 25, 2019
1 parent daa398d commit 059c8ac
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 31 deletions.
23 changes: 10 additions & 13 deletions rt_gene/scripts/estimate_gaze.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def __init__(self, device_id_gaze, model_files):
self.headpose_frame = self.tf_prefix + "/head_pose_estimated"
self.ros_tf_frame = rospy.get_param("~ros_tf_frame", "/kinect2_nonrotated_link")

self.image_subscriber = rospy.Subscriber('/subjects/images', MSG_SubjectImagesList, self.image_callback, queue_size=3, buff_size=2**24)
self.subjects_gaze_img = rospy.Publisher('/subjects/gazeimages', Image, queue_size=3)
self.image_subscriber = rospy.Subscriber("/subjects/images", MSG_SubjectImagesList, self.image_callback, queue_size=3, buff_size=2**24)
self.subjects_gaze_img = rospy.Publisher("/subjects/gazeimages", Image, queue_size=3)

self.visualise_eyepose = rospy.get_param("~visualise_eyepose", default=True)

Expand All @@ -59,7 +59,6 @@ def image_callback(self, subject_image_list):
and this image is published along with the estimated gaze vector (see :meth:`publish_image` and
:func:`publish_gaze`)"""
timestamp = subject_image_list.header.stamp
subjects_gaze_img = None

subjects_dict = self.subjects_bridge.msg_to_images(subject_image_list)
input_r_list = []
Expand Down Expand Up @@ -87,6 +86,7 @@ def image_callback(self, subject_image_list):
inference_input_right_list=input_r_list,
inference_headpose_list=input_head_list)

subjects_gaze_img_list = []
for subject_id, gaze in zip(valid_subject_list, gaze_est.tolist()):
self.publish_gaze(gaze, timestamp, subject_id)

Expand All @@ -95,14 +95,11 @@ def image_callback(self, subject_image_list):
r_gaze_img = self.visualize_eye_result(s.right, gaze)
l_gaze_img = self.visualize_eye_result(s.left, gaze)
s_gaze_img = np.concatenate((r_gaze_img, l_gaze_img), axis=1)
subjects_gaze_img_list.append(s_gaze_img)

if subjects_gaze_img is None:
subjects_gaze_img = s_gaze_img
else:
subjects_gaze_img = np.concatenate((subjects_gaze_img, s_gaze_img), axis=0)

if subjects_gaze_img is not None:
gaze_img_msg = self.bridge.cv2_to_imgmsg(subjects_gaze_img.astype(np.uint8), "bgr8")
if len(subjects_gaze_img_list) > 0:
gaze_img_msg = self.bridge.cv2_to_imgmsg(np.hstack(subjects_gaze_img_list).astype(np.uint8), "bgr8")
gaze_img_msg.header.stamp = timestamp
self.subjects_gaze_img.publish(gaze_img_msg)

def publish_gaze(self, est_gaze, msg_stamp, subject_id):
Expand All @@ -115,11 +112,11 @@ def publish_gaze(self, est_gaze, msg_stamp, subject_id):
quaternion_gaze, msg_stamp, self.tf_prefix + "/world_gaze" + str(subject_id), self.headpose_frame + str(subject_id))


if __name__ == '__main__':
if __name__ == "__main__":
try:
rospy.init_node('estimate_gaze')
rospy.init_node("estimate_gaze")
gaze_estimator = GazeEstimatorROS(rospy.get_param("~device_id_gazeestimation", default="/gpu:0"),
[os.path.join(rospkg.RosPack().get_path('rt_gene'), model_file) for model_file in rospy.get_param("~model_files")])
[os.path.join(rospkg.RosPack().get_path("rt_gene"), model_file) for model_file in rospy.get_param("~model_files")])
rospy.spin()
except rospy.exceptions.ROSInterruptException:
print("See ya")
Expand Down
13 changes: 4 additions & 9 deletions rt_gene/scripts/extract_landmarks_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def process_image(self, color_msg):

self.subject_tracker.update_eye_images(self.eye_image_size)

final_head_pose_images = None
final_head_pose_images = []
for subject_id, subject in self.subject_tracker.get_tracked_elements().items():
if subject.left_eye_color is None or subject.right_eye_color is None:
continue
Expand All @@ -114,20 +114,15 @@ def process_image(self, color_msg):
roll_pitch_yaw = gaze_tools.limit_yaw(head_rpy)
face_image_resized = cv2.resize(subject.face_color, dsize=(224, 224), interpolation=cv2.INTER_CUBIC)

head_pose_image = LandmarkMethodROS.visualize_headpose_result(face_image_resized, gaze_tools.get_phi_theta_from_euler(roll_pitch_yaw))

if final_head_pose_images is None:
final_head_pose_images = head_pose_image
else:
final_head_pose_images = np.concatenate((final_head_pose_images, head_pose_image), axis=1)
final_head_pose_images.append(LandmarkMethodROS.visualize_headpose_result(face_image_resized, gaze_tools.get_phi_theta_from_euler(roll_pitch_yaw)))
else:
tqdm.write("Could not get head pose properly")

if len(self.subject_tracker.get_tracked_elements().items()) > 0:
self.publish_subject_list(timestamp, self.subject_tracker.get_tracked_elements())

if final_head_pose_images is not None:
headpose_image_ros = self.bridge.cv2_to_imgmsg(final_head_pose_images, "bgr8")
if len(final_head_pose_images) > 0:
headpose_image_ros = self.bridge.cv2_to_imgmsg(np.hstack(final_head_pose_images), "bgr8")
headpose_image_ros.header.stamp = timestamp
self.subject_faces_pub.publish(headpose_image_ros)

Expand Down
5 changes: 2 additions & 3 deletions rt_gene/src/rt_gene/estimate_gaze_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import cv2
import numpy as np
import tensorflow as tf
from rt_gene.gaze_tools import accuracy_angle, angle_loss, get_endpoint
from rt_gene.gaze_tools import get_endpoint
from tqdm import tqdm


Expand Down Expand Up @@ -47,8 +47,7 @@ def __init__(self, device_id_gaze, model_files):

for model_file in model_files:
tqdm.write('Load model ' + model_file)
models.append(tf.keras.models.load_model(model_file,
custom_objects={'accuracy_angle': accuracy_angle, 'angle_loss': angle_loss}))
models.append(tf.keras.models.load_model(model_file, compile=False))
# noinspection PyProtectedMember
models[-1]._name = "model_{}".format(len(models))

Expand Down
6 changes: 3 additions & 3 deletions rt_gene/src/rt_gene/subject_ros_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ def __init__(self):

def msg_to_images(self, subject_msg):
subject = SubjectImages(subject_msg.subject_id)
subject.face = self.__cv_bridge.imgmsg_to_cv2(subject_msg.face_img, 'rgb8')
subject.right = self.__cv_bridge.imgmsg_to_cv2(subject_msg.right_eye_img, 'rgb8')
subject.left = self.__cv_bridge.imgmsg_to_cv2(subject_msg.left_eye_img, 'rgb8')
subject.face = self.__cv_bridge.imgmsg_to_cv2(subject_msg.face_img, "rgb8")
subject.right = self.__cv_bridge.imgmsg_to_cv2(subject_msg.right_eye_img, "rgb8")
subject.left = self.__cv_bridge.imgmsg_to_cv2(subject_msg.left_eye_img, "rgb8")
return subject

def images_to_msg(self, subject_id, subject):
Expand Down
2 changes: 1 addition & 1 deletion rt_gene_inpainting/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# RT-GENE: Real-Time Eye Gaze Estimation in Natural Environments
[![License: CC BY-NC-SA 4.0](https://img.shields.io/badge/License-CC%20BY--NC--SA%204.0-lightgrey.svg?style=flat-square)](https://creativecommons.org/licenses/by-nc-sa/4.0/)
[![HitCount](http://hits.dwyl.io/Tobias-Fischer/rt_gene.svg)](http://hits.dwyl.io/Tobias-Fischer/rt_gene)
![HitCount](http://hits.dwyl.io/Tobias-Fischer/rt_gene.svg)
![stars](https://img.shields.io/github/stars/Tobias-Fischer/rt_gene.svg?style=flat-square)
![GitHub issues](https://img.shields.io/github/issues/Tobias-Fischer/rt_gene.svg?style=flat-square)
![GitHub repo size](https://img.shields.io/github/repo-size/Tobias-Fischer/rt_gene.svg?style=flat-square)
Expand Down
2 changes: 1 addition & 1 deletion rt_gene_model_training/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# RT-GENE: Real-Time Eye Gaze Estimation in Natural Environments
[![License: CC BY-NC-SA 4.0](https://img.shields.io/badge/License-CC%20BY--NC--SA%204.0-lightgrey.svg?style=flat-square)](https://creativecommons.org/licenses/by-nc-sa/4.0/)
[![HitCount](http://hits.dwyl.io/Tobias-Fischer/rt_gene.svg)](http://hits.dwyl.io/Tobias-Fischer/rt_gene)
![HitCount](http://hits.dwyl.io/Tobias-Fischer/rt_gene.svg)
![stars](https://img.shields.io/github/stars/Tobias-Fischer/rt_gene.svg?style=flat-square)
![GitHub issues](https://img.shields.io/github/issues/Tobias-Fischer/rt_gene.svg?style=flat-square)
![GitHub repo size](https://img.shields.io/github/repo-size/Tobias-Fischer/rt_gene.svg?style=flat-square)
Expand Down
2 changes: 1 addition & 1 deletion rt_gene_standalone/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# RT-GENE: Real-Time Eye Gaze Estimation in Natural Environments
[![License: CC BY-NC-SA 4.0](https://img.shields.io/badge/License-CC%20BY--NC--SA%204.0-lightgrey.svg?style=flat-square)](https://creativecommons.org/licenses/by-nc-sa/4.0/)
[![HitCount](http://hits.dwyl.io/Tobias-Fischer/rt_gene.svg)](http://hits.dwyl.io/Tobias-Fischer/rt_gene)
![HitCount](http://hits.dwyl.io/Tobias-Fischer/rt_gene.svg)
![stars](https://img.shields.io/github/stars/Tobias-Fischer/rt_gene.svg?style=flat-square)
![GitHub issues](https://img.shields.io/github/issues/Tobias-Fischer/rt_gene.svg?style=flat-square)
![GitHub repo size](https://img.shields.io/github/repo-size/Tobias-Fischer/rt_gene.svg?style=flat-square)
Expand Down

0 comments on commit 059c8ac

Please sign in to comment.