Mauro Rovezzi / bliss · commit e16f4398

authored Dec 08, 2017 by Sebastien Petitdemange
committed by Vincent Michel, Jan 12, 2018

Commit message: pep8

parent: 1caf9975
Changes: 23 files
bliss/common/data_file_manager.py
-import os, errno
+import os
+import errno
 import h5py
 from bliss.scanning.chain import AcquisitionDevice, AcquisitionMaster


 class FileOrganizer(object):
     def __init__(self, root_path,
                  windows_path_mapping=None,
                  detector_temporay_path=None, **keys):
         """ A default way to organize file structure

         windows_path_mapping -- transform unix path to windows
         ...
@@ -16,36 +18,41 @@ class FileOrganizer(object):
         self._root_path = root_path
         self._windows_path_mapping = windows_path_mapping or dict()
         self._detector_temporay_path = detector_temporay_path or dict()


 class Hdf5Organizer(FileOrganizer):
     def __init__(self, root_path, **keys):
         FileOrganizer.__init__(self, root_path, **keys)
         self.file = None

     def _acq_device_event(self, event_dict=None, signal=None, sender=None):
         print 'received', signal, 'from', sender, ":", event_dict

     def prepare(self, scan_recorder, scan_info, devices_tree):
         path_suffix = scan_recorder.node.db_name().replace(':', os.path.sep)
         full_path = os.path.join(self._root_path, path_suffix)
         try:
             os.makedirs(full_path)
         except OSError as exc:  # Python >2.5
             if exc.errno == errno.EEXIST and os.path.isdir(path):
                 pass
             else:
                 raise

         self.file = h5py.File(os.path.join(full_path, 'data.h5'))
         scan_entry = h5py.Group(self.file, scan_recorder.name, create=True)
         scan_entry.attrs['NX_class'] = 'NXentry'
         measurement = h5py.Group(scan_entry, 'measurement', create=True)

         master_id = 0
         for dev, node in scan_recorder.nodes.iteritems():
             if isinstance(dev, AcquisitionMaster):
                 master_entry = h5py.Group(measurement, 'master%d' % master_id,
                                           create=True)
                 master_id += 1
                 for slave in dev.slaves:
                     if isinstance(slave, AcquisitionDevice):
                         for signal in ('start', 'end', 'new_ref', 'new_data'):
                             dispatcher.connect(self._acq_device_event,
                                                signal, dev)
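`Hdf5Organizer.prepare` lays out a NeXus-style tree: one `NXentry` group per scan, a `measurement` subgroup, and a `master<N>` group per `AcquisitionMaster`. (Note that `os.path.isdir(path)` in the `makedirs` guard references an undefined name; `full_path` is presumably intended.) Below is a minimal sketch of the resulting layout written against the modern h5py API rather than the legacy `h5py.Group(parent, name, create=True)` constructor used above; names other than `measurement` and `NX_class` are illustrative.

    import h5py

    # Sketch only: reproduce the group layout prepare() builds,
    # using h5py >= 2 idioms (require_group creates a group if missing).
    with h5py.File('data.h5', 'a') as f:
        scan_entry = f.require_group('scan1')           # scan_recorder.name (illustrative)
        scan_entry.attrs['NX_class'] = 'NXentry'        # NeXus entry marker
        measurement = scan_entry.require_group('measurement')
        master0 = measurement.require_group('master0')  # one group per AcquisitionMaster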
bliss/common/utils.py
@@ -11,26 +11,27 @@ import itertools
 import functools

 try:
     from collections import OrderedDict
 except ImportError:
     # python2.6 compatibility
     from ordereddict import OrderedDict


 class WrappedMethod(object):
     def __init__(self, control, method_name):
         self.method_name = method_name
         self.control = control

     def __call__(self, this, *args, **kwargs):
         return getattr(self.control, self.method_name)(*args, **kwargs)


 def wrap_methods(from_object, target_object):
     for name in dir(from_object):
         if inspect.ismethod(getattr(from_object, name)):
             if hasattr(target_object, name) and \
                inspect.ismethod(getattr(target_object, name)):
                 continue
             setattr(target_object, name,
                     types.MethodType(WrappedMethod(from_object, name),
                                      target_object,
                                      target_object.__class__))


 def add_conversion_function(obj, method_name, function):
 ...
@@ -40,7 +41,7 @@ def add_conversion_function(obj, method_name, function):
         def new_method(*args, **kwargs):
             values = meth(*args, **kwargs)
             return function(values)
         setattr(obj, method_name, new_method)
     else:
         raise ValueError("conversion function must be callable")
     else:
     ...
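`wrap_methods` forwards every method of `from_object` onto `target_object` through `WrappedMethod`, using the Python 2 three-argument `types.MethodType(func, instance, class)` binding (the `this` parameter of `__call__` receives the facade and is deliberately ignored). A minimal usage sketch, Python 2 like the module itself; the `Controller` and `Facade` classes are invented for illustration:

    from bliss.common.utils import wrap_methods

    class Controller(object):
        def move(self, position):
            return 'moving to %r' % position

    class Facade(object):
        pass

    ctrl = Controller()
    facade = Facade()
    wrap_methods(ctrl, facade)   # copy Controller's methods onto the facade
    print facade.move(10)        # delegates to ctrl.move -> 'moving to 10'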
@@ -48,25 +49,29 @@ def add_conversion_function(obj, method_name, function):
 def add_property(inst, name, method):
     cls = type(inst)
     if not hasattr(cls, '__perinstance'):
         cls = type(cls.__name__, (cls,), {})
         cls.__perinstance = True
         inst.__class__ = cls
     setattr(cls, name, property(method))


 def grouped(iterable, n):
     "s -> (s0,s1,s2,...sn-1), (sn,sn+1,sn+2,...s2n-1), (s2n,s2n+1,s2n+2,...s3n-1), ..."
     return itertools.izip(*[iter(iterable)] * n)


 def all_equal(iterable):
     g = itertools.groupby(iterable)
     return next(g, True) and not next(g, False)


 """
 functions to add custom attributes and commands to an object.
 """


 def add_object_method(obj, method, pre_call, name=None, args=[],
                       types_info=(None, None)):
     if name is None:
     ...
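Three small utilities above are worth illustrating: `add_property` attaches a property to a single instance by swapping in a one-off subclass (so sibling instances are untouched), `grouped` chunks an iterable n at a time (dropping any remainder), and `all_equal` tests whether every element compares equal. A hedged sketch, Python 2; the `Motor` class is invented:

    from bliss.common.utils import add_property, grouped, all_equal

    class Motor(object):
        pass

    m1, m2 = Motor(), Motor()
    add_property(m1, 'position', lambda self: 42)
    print m1.position              # -> 42
    print hasattr(m2, 'position')  # -> False: only m1's class was specialized

    print list(grouped([1, 2, 3, 4, 5], 2))  # -> [(1, 2), (3, 4)]; the 5 is dropped
    print all_equal([7, 7, 7])               # -> True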
@@ -98,27 +103,31 @@ def object_method(method=None, name=None, args=[], types_info=(None, None), filt
     # Returns a method where _object_method_ attribute is filled with a
     # dict of elements to characterize it.
     method._object_method_ = dict(name=name, args=args,
                                   types_info=types_info, filter=filter)
     return method


 def object_method_type(method=None, name=None, args=[],
                        types_info=(None, None), type=None):
-    f = lambda x: isinstance(x, type)
+    def f(x):
+        return isinstance(x, type)
     return object_method(method=method, name=name, args=args,
                          types_info=types_info, filter=f)


 def add_object_attribute(obj, name=None, fget=None, fset=None, args=[],
                          type_info=None, filter=None):
     obj._add_custom_attribute(name, fget, fset, type_info)


 """
 decorators for set/get methods to access to custom attributes
 """


 def object_attribute_type_get(get_method=None, name=None, args=[],
                               type_info=None, type=None):
-    f = lambda x: isinstance(x, type)
+    def f(x):
+        return isinstance(x, type)
     return object_attribute_get(get_method=get_method, name=name, args=args,
                                 type_info=type_info, filter=f)


 def object_attribute_get(get_method=None, name=None, args=[],
                          type_info=None, filter=None):
     if get_method is None:
         return functools.partial(object_attribute_get, name=name, args=args,
         ...
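The `object_method`/`object_attribute_*` helpers follow the standard pattern for decorators with optional keyword arguments: called without a function (the `get_method is None` branch above), they return a `functools.partial` of themselves that receives the function on the second call. The pattern in isolation, as a sketch; `mark` and `_meta` are invented names:

    import functools

    def mark(method=None, name=None):
        if method is None:
            # invoked as @mark(name=...): defer until the function arrives
            return functools.partial(mark, name=name)
        method._meta = dict(name=name or method.func_name)  # func_name: Python 2
        return method

    @mark                    # bare use: mark(get_position)
    def get_position(self):
        return 0

    @mark(name='velocity')   # parametrized use: mark(name=...)(get_velocity)
    def get_velocity(self):
        return 0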
@@ -128,21 +137,24 @@ def object_attribute_get(get_method=None, name=None, args=[], type_info=None, fi
     name = get_method.func_name

     attr_name = name
     if attr_name.startswith("get_"):
         attr_name = attr_name[4:]  # removes leading "get_"

     get_method._object_method_ = dict(name=name, args=args,
                                       types_info=("None", type_info),
                                       filter=filter)

     if not hasattr(get_method, "_object_attribute_"):
         get_method._object_attribute_ = dict()
     get_method._object_attribute_.update(name=attr_name, fget=get_method,
                                          args=args, type_info=type_info,
                                          filter=filter)
     return get_method


 def object_attribute_type_set(set_method=None, name=None, args=[],
                               type_info=None, type=None):
-    f = lambda x: isinstance(x, type)
+    def f(x):
+        return isinstance(x, type)
     return object_attribute_set(set_method=set_method, name=name, args=args,
                                 type_info=type_info, filter=f)


 def object_attribute_set(set_method=None, name=None, args=[],
                          type_info=None, filter=None):
     if set_method is None:
         return functools.partial(object_attribute_set, name=name, args=args,
         ...
@@ -152,13 +164,15 @@ def object_attribute_set(set_method=None, name=None, args=[], type_info=None, fi
     name = set_method.func_name

     attr_name = name
     if attr_name.startswith("set_"):
         attr_name = attr_name[4:]  # removes leading "set_"

     set_method._object_method_ = dict(name=name, args=args,
                                       types_info=(type_info, "None"),
                                       filter=filter)

     if not hasattr(set_method, "_object_attribute_"):
         set_method._object_attribute_ = dict()
     set_method._object_attribute_.update(name=attr_name, fset=set_method,
                                          args=args, type_info=type_info,
                                          filter=filter)
     return set_method
 ...
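Taken together, the two decorators describe one custom attribute per `get_<attr>`/`set_<attr>` pair: each strips its prefix to derive the attribute name and records `fget` or `fset` in `_object_attribute_`, with `types_info` oriented `("None", type)` for getters and `(type, "None")` for setters. A sketch of the intended use; the `type_info='float'` value and `_position` field are assumptions:

    from bliss.common.utils import object_attribute_get, object_attribute_set

    @object_attribute_get(type_info='float')
    def get_position(self):
        return self._position

    @object_attribute_set(type_info='float')
    def set_position(self, value):
        self._position = value

    # get_position._object_attribute_ -> dict(name='position', fget=get_position, ...)
    # set_position._object_attribute_ -> dict(name='position', fset=set_position, ...)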
@@ -174,7 +188,7 @@ def set_custom_members(src_obj, target_obj, pre_call=None):
         attribute_info = dict(member._object_attribute_)
         filter = attribute_info.pop('filter', None)
         if filter is None or filter(target_obj):
             add_object_attribute(target_obj, **member._object_attribute_)

     # For each method of <src_obj>: try to add it as a
     # custom method or as methods to set/get custom
     ...
@@ -237,7 +251,8 @@ def with_custom_members(klass):
         access_mode = 'r' if fget else ''
         access_mode += 'w' if fset else ''
         if fget is None and fset is None:
             raise RuntimeError("impossible case: must have fget or fset...")
         custom_attrs[name] = type_info, access_mode

     klass._get_custom_methods = _get_custom_methods
     ...
@@ -250,7 +265,6 @@ def with_custom_members(klass):
     return klass


 class Null(object):
     __slots__ = []
     ...
bliss/config/settings.py

This diff is collapsed.
bliss/data/lima.py
@@ -14,30 +14,31 @@ from bliss.config.conductor import client
 from bliss.config import channels
 import gevent


 class LimaDataNode(DataNode):
     class DataChannel(object):
         def __init__(self, dataset):
             self._dataset = dataset
             self._device_proxy = None
             self._image_mode = {
                 0: numpy.uint8,
                 1: numpy.uint16,
                 2: numpy.uint32,
                 4: numpy.int8,
                 5: numpy.int16,
                 6: numpy.int32,
             }

         def get(self, from_index, to_index=None):
             cnx = self._dataset._data._cnx()
             url = self._dataset._data.url_server
             if url is None:
                 # data is no more available
                 raise RuntimeError('dataset is no more available')

             current_lima_acq = int(cnx.get(url))
             (lima_acq_nb, acq_nb_buffer, LastImageAcquired, LastCounterReady,
              LastImageSaved) = [int(x) for x in cnx.hmget(self.db_name,
                                                           'lima_acq_nb',
                                                           'acq_nb_buffer',
                                                           ...
@@ -45,19 +46,19 @@ class LimaDataNode(DataNode):
                                                           'LastCounterReady',
                                                           'LastImageSaved')]
             if to_index is None:
-                #first we try to get image directly from the server
+                # first we try to get image directly from the server
                 if current_lima_acq == lima_acq_nb:  # current acquisition
                     if LastImageAcquired < from_index:  # image is not yet available
                         raise RuntimeError('image is not yet available')
-                    #should be still in server memory
+                    # should be still in server memory
                     if acq_nb_buffer > LastImageAcquired - from_index:
                         try:
                             if self._device_proxy is None:
                                 self._device_proxy = DeviceProxy(url)
                             raw_msg = self._device_proxy.readImage(from_index)
                             return self._tango_unpack(raw_msg[-1])
                         except:
                             # As it's asynchronous, image seams to be no
                             # more available so read it from file
                             return self._read_from_file(from_index)
 ...
@@ -66,26 +67,28 @@ class LimaDataNode(DataNode):
             else:
                 raise NotImplementedError('Not yet done')

     def _tango_unpack(self, msg):
         struct_format = '<IHHIIHHHHHHHHHHHHHHHHHHIII'
         header_size = struct.calcsize(struct_format)
         values = struct.unpack(msg[:header_size])
         if values[0] != 0x44544159:
             raise RuntimeError('Not a lima data')
         header_offset = values[2]
         data = numpy.fromstring(msg[header_offset:],
                                 data=self._image_mode.get(values[4]))
         data.shape = values[8], values[7]
         return data

     def _read_from_file(self, from_index):
         #@todo should read file from any format?????
         for saving_parameters in self._dataset._saving_params:
             pass

     def __init__(self, name, **keys):
         DataNode.__init__(self, 'lima', name, **keys)
         saving_params_name = '%s_saving_params' % self.db_name()
         self._saving_params = QueueObjSetting(saving_params_name,
                                               connection=self.db_connection)
         self._storage_task = None

     def channel_name(self):
     ...
@@ -93,13 +96,13 @@ class LimaDataNode(DataNode):
     #@brief update image status
     #
     def update_status(self, image_status):
         cnx = self._data._cnx()
         db_name = self.db_name()

         pipeline = cnx.pipeline()
         for key, value in image_status.iteritems():
             pipeline.hset(db_name, key, value)
         pipeline.execute()

     def _end_storage(self):
     ...
@@ -107,7 +110,7 @@ class LimaDataNode(DataNode):
         if self._storage_task is not None:
             self._new_image_status_event.set()
             self._storage_task.join()

     def _do_store(self):
         while True:
             self._new_image_status_event.wait()
             ...
@@ -115,8 +118,8 @@ class LimaDataNode(DataNode):
             local_dict = self._new_image_status
             self._new_image_status = dict()
             if local_dict:
                 self.db_connection.hmset(self.db_name(), local_dict)
             if self._stop_flag:
                 break
             gevent.idle()
 ...
@@ -137,40 +140,39 @@ class LimaDataNode(DataNode):
             self._new_image_status.update(local_dict)
             self._new_image_status_event.set()

     #@brief set the number of buffer for this acquisition
     def set_nb_buffer(self, acq_nb_buffer):
         self._data.acq_nb_buffer = acq_nb_buffer

     #@brief set the server url and
-    #calculate an unique id for this acquisition
+    # calculate an unique id for this acquisition
     def set_server_url(self, url):
         self._data.url_server = url

         cnx = self._data._cnx()
         self._data.lima_acq_nb = cnx.incr(url)

     def set_acq_parameters(self, acq_params):
         self.set_info('acq_params', acq_params)

     #@brief saving parameters
     def add_saving_parameters(self, parameters):
         self._saving_params.append(parameters)
         if self._ttl > 0:
             self._saving_params.ttl(self._ttl)

     #@brief this methode should retrives all files
-    #references for this data set
+    # references for this data set
     def get_file_references(self):
-        #take the last in list because it's should be the final
+        # take the last in list because it's should be the final
         final_params = self._saving_params[-1]
         acq_params = self._info['acq_params']
-        #in that case only one reference will be return
+        # in that case only one reference will be return
         overwritePolicy = final_params['overwritePolicy'].lower()
         if overwritePolicy == 'multiset':
             last_file_number = final_params['nextNumber'] + 1
         else:
             nb_files = int(math.ceil(float(acq_params['acqNbFrames']) /
                                      final_params['framesPerFile']))
             last_file_number = final_params['nextNumber'] + nb_files
             ...
@@ -189,8 +191,7 @@ class LimaDataNode(DataNode):
             references.append(full_path)
         return references

     #@brief for now lima has only on data channel
-    #we will provide in a second time all counters (roi_counters,roi_spectrum...)
+    # we will provide in a second time all counters (roi_counters,roi_spectrum...)
     def get_channel(self, **keys):
         return DatasetLima.DataChannel(self)
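Two apparent bugs survive the `_tango_unpack` hunk: `struct.unpack(msg[:header_size])` is missing its format argument, and `numpy.fromstring(..., data=...)` passes `data=` where `dtype=` must be meant. A sketch of how the header parse is presumably intended to work; the magic `0x44544159` is the byte string 'YATD' read as a little-endian uint32:

    import struct
    import numpy

    STRUCT_FORMAT = '<IHHIIHHHHHHHHHHHHHHHHHHIII'  # little-endian header layout
    HEADER_SIZE = struct.calcsize(STRUCT_FORMAT)

    def unpack_image(msg, image_mode):
        # Sketch only: parse the header, check the magic, then view the
        # payload with the dtype selected by the image-mode field.
        values = struct.unpack(STRUCT_FORMAT, msg[:HEADER_SIZE])
        if values[0] != 0x44544159:
            raise RuntimeError('Not a lima data')
        header_offset = values[2]
        data = numpy.frombuffer(msg[header_offset:],
                                dtype=image_mode.get(values[4]))
        return data.reshape(values[8], values[7])  # rows, columns as in the code above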
bliss/data/node.py
@@ -15,29 +15,34 @@ from bliss.common.event import dispatcher
 from bliss.config.conductor import client
 from bliss.config.settings import Struct, QueueSetting, HashObjSetting


 def to_timestamp(dt, epoch=None):
     if epoch is None:
         epoch = datetime.datetime(1970, 1, 1)
     td = dt - epoch
     return td.microseconds / float(10**6) + td.seconds + td.days * 86400


 # From continuous scan
 node_plugins = dict()
 for importer, module_name, _ in pkgutil.iter_modules(
         [os.path.join(os.path.dirname(__file__), '..', 'data')]):
     node_plugins[module_name] = importer


 def _get_node_object(node_type, name, parent, connection, create=False):
     importer = node_plugins.get(node_type)
     if importer is None:
         return DataNode(node_type, name, parent, connection=connection,
                         create=create)
     else:
         m = importer.find_module(node_type).load_module(node_type)
         classes = inspect.getmembers(m,
                                      lambda x: inspect.isclass(x) and
                                      issubclass(x, DataNode) and
                                      x != DataNode)
         # there should be only 1 class inheriting from DataNode in the plugin
         klass = classes[0][-1]
         return klass(name, parent=parent, connection=connection,
                      create=create)


 def get_node(name, node_type=None, parent=None, connection=None):
     if connection is None:
         connection = client.get_cache(db=1)
     data = Struct(name, connection=connection)
     ...
@@ -48,12 +53,14 @@ def get_node(name, node_type = None, parent = None, connection = None):
     return _get_node_object(node_type, name, parent, connection)


 def _create_node(name, node_type=None, parent=None, connection=None):
     if connection is None:
         connection = client.get_cache(db=1)
     return _get_node_object(node_type, name, parent, connection, create=True)


 def _get_or_create_node(name, node_type=None, parent=None, connection=None):
     if connection is None:
         connection = client.get_cache(db=1)
     db_name = DataNode.exists(name, parent, connection)
     ...
@@ -62,12 +69,13 @@ def _get_or_create_node(name, node_type=None, parent=None, connection = None):
     else:
         return _create_node(name, node_type, parent, connection)


 class DataNodeIterator(object):
     NEW_CHILD_REGEX = re.compile("^__keyspace@.*?:(.*)_children_list$")
     NEW_CHANNEL_REGEX = re.compile("^__keyspace@.*?:(.*)_channels$")
     NEW_CHILD_EVENT, NEW_CHANNEL_EVENT, NEW_DATA_IN_CHANNEL_EVENT = range(3)

     def __init__(self, node, last_child_id=None):
         self.node = node
         self.last_child_id = dict() if last_child_id is None else last_child_id
         self.zerod_channel_event = dict()
         ...
@@ -76,30 +84,31 @@ class DataNodeIterator(object):
         """Iterate over child nodes that match the `filter` argument

         If wait is True (default), the function blocks until a new node appears
         """
         if isinstance(filter, (str, unicode)):
             filter = (filter, )
         else:
             filter = tuple(filter)

         if wait:
             pubsub = self.children_event_register()

         db_name = self.node.db_name()
         self.last_child_id[db_name] = 0

         if filter is None or self.node.type() in filter:
             yield self.node

         for i, child in enumerate(self.node.children()):
             iterator = DataNodeIterator(child, last_child_id=self.last_child_id)
             for n in iterator.walk(filter, wait=False):
                 self.last_child_id[db_name] = i + 1
                 if filter is None or n.type() in filter:
                     yield n

         if wait:
             # yield from self.wait_for_event(pubsub)
             for event_type, value in self.wait_for_event(pubsub, filter):
                 if event_type is self.NEW_CHILD_EVENT:
                     yield value
                     ...
@@ -123,10 +132,10 @@ class DataNodeIterator(object):
         (like NEW_CHILD_EVENT or NEW_DATA_IN_CHANNEL_EVENT) instead of node objects
         """
         pubsub = self.children_event_register()
         for node in self.walk(filter, wait=False):
             self.child_register_new_data(node, pubsub)
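`to_timestamp` converts a `datetime` into seconds relative to an epoch by summing days, seconds and microseconds by hand, which keeps it working on Python 2.6 where `timedelta.total_seconds()` does not exist. A quick illustrative sanity check:

    import datetime
    from bliss.data.node import to_timestamp

    dt = datetime.datetime(1970, 1, 2, 0, 0, 1)
    print to_timestamp(dt)                                       # -> 86401.0
    print to_timestamp(dt, epoch=datetime.datetime(1970, 1, 2))  # -> 1.0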