Skip to content

Commit cf6b069

Browse files
committed
Add some of the state/logic for col/row cursors
1 parent 03706d7 commit cf6b069

File tree

1 file changed

+36
-16
lines changed

1 file changed

+36
-16
lines changed

src/textual/widgets/_data_table.py

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from ..strip import Strip
2424
from .._typing import Literal
2525

26-
CursorType = Literal["cell", "row", "column"]
26+
CursorType = Literal["cell", "row", "column", "none"]
2727
CELL: CursorType = "cell"
2828
CellType = TypeVar("CellType")
2929

@@ -490,8 +490,8 @@ def _render_row(
490490
row_index: int,
491491
line_no: int,
492492
base_style: Style,
493-
cursor_column: int = -1,
494-
hover_column: int = -1,
493+
cursor_location: Coord,
494+
hover_location: Coord,
495495
) -> tuple[SegmentLines, SegmentLines]:
496496
"""Render a row in to lines for each cell.
497497
@@ -504,7 +504,7 @@ def _render_row(
504504
tuple[Lines, Lines]: Lines for fixed cells, and Lines for scrollable cells.
505505
"""
506506

507-
cache_key = (row_index, line_no, base_style, cursor_column, hover_column)
507+
cache_key = (row_index, line_no, base_style, cursor_location, hover_location)
508508

509509
if cache_key in self._row_render_cache:
510510
return self._row_render_cache[cache_key]
@@ -534,17 +534,38 @@ def _render_row(
534534
else:
535535
row_style = base_style
536536

537-
scrollable_row = [
538-
render_cell(
537+
def should_highlight(
538+
target_location: Coord,
539+
cell_location: Coord,
540+
cursor_type: CursorType,
541+
) -> bool:
542+
if cursor_type == "cell":
543+
return target_location == cell_location
544+
elif cursor_type == "row":
545+
target_row, _ = target_location
546+
cell_row, _ = cell_location
547+
return target_row == cell_row
548+
elif cursor_type == "column":
549+
_, target_column = target_location
550+
_, cell_column = cell_location
551+
return target_column == cell_column
552+
else:
553+
return False
554+
555+
cursor_type = self.cursor_type
556+
557+
scrollable_row = []
558+
for column in self.columns:
559+
cell_location = Coord(row_index, column.index)
560+
cell_lines = render_cell(
539561
row_index,
540562
column.index,
541563
row_style,
542564
column.render_width,
543-
cursor=cursor_column == column.index,
544-
hover=hover_column == column.index,
565+
cursor=should_highlight(cursor_location, cell_location, cursor_type),
566+
hover=should_highlight(hover_location, cell_location, cursor_type),
545567
)[line_no]
546-
for column in self.columns
547-
]
568+
scrollable_row.append(cell_lines)
548569

549570
row_pair = (fixed_row, scrollable_row)
550571
self._row_render_cache[cache_key] = row_pair
@@ -586,23 +607,24 @@ def _render_line(self, y: int, x1: int, x2: int, base_style: Style) -> Strip:
586607
row_index, line_no = self._get_offsets(y)
587608
except LookupError:
588609
return Strip.blank(width, base_style)
610+
589611
cursor_column = (
590612
self.cursor_column
591613
if (self.show_cursor and self.cursor_row == row_index)
592614
else -1
593615
)
594616
hover_column = self.hover_column if (self.hover_row == row_index) else -1
595617

596-
cache_key = (y, x1, x2, width, cursor_column, hover_column, base_style)
618+
cache_key = (y, x1, x2, width, self.cursor_cell, self.hover_cell, base_style)
597619
if cache_key in self._line_cache:
598620
return self._line_cache[cache_key]
599621

600622
fixed, scrollable = self._render_row(
601623
row_index,
602624
line_no,
603625
base_style,
604-
cursor_column=cursor_column,
605-
hover_column=hover_column,
626+
cursor_location=self.cursor_cell,
627+
hover_location=self.hover_cell,
606628
)
607629
fixed_width = sum(
608630
column.render_width for column in self.columns[: self.fixed_columns]
@@ -626,12 +648,10 @@ def render_line(self, y: int) -> Strip:
626648
if self.show_header:
627649
fixed_top_row_count += self.get_row_height(-1)
628650

629-
style = self.rich_style
630-
631651
if y >= fixed_top_row_count:
632652
y += scroll_y
633653

634-
return self._render_line(y, scroll_x, scroll_x + width, style)
654+
return self._render_line(y, scroll_x, scroll_x + width, self.rich_style)
635655

636656
def on_mouse_move(self, event: events.MouseMove):
637657
meta = event.style.meta

0 commit comments

Comments
 (0)