您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

draw.tsx 9.2KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. import * as React from 'react'
  2. import { TLBounds, Utils, Vec, TLTransformInfo, Intersect } from '@tldraw/core'
  3. import getStroke, { getStrokePoints } from 'perfect-freehand'
  4. import { defaultStyle, getShapeStyle } from '~shape/shape-styles'
  5. import {
  6. DrawShape,
  7. DashStyle,
  8. TLDrawShapeUtil,
  9. TLDrawShapeType,
  10. TLDrawToolType,
  11. TLDrawRenderInfo,
  12. } from '~types'
  /**
   * Shape utility for freehand ("draw") shapes.
   *
   * A draw shape stores its stroke as `points` relative to `shape.point`; each point is
   * destructured below as `[x, y, pressure]`. Derived data (bounds, SVG path strings,
   * rotated points) is memoised in WeakMaps keyed by the `points` array or by the shape
   * object, so it is recomputed only when a new array / shape object is produced.
   * NOTE(review): the caches assume `points` arrays are replaced, never mutated in place — confirm.
   */
  export class Draw extends TLDrawShapeUtil<DrawShape> {
    type = TLDrawShapeType.Draw as const

    toolType = TLDrawToolType.Draw

    /** Bounds of a points array, untranslated (i.e. relative to `shape.point`). */
    pointsBoundsCache = new WeakMap<DrawShape['points'], TLBounds>([])

    /** Points rotated about their bounds centre by `shape.rotation`; used by `hitTestBounds`. */
    rotatedCache = new WeakMap<DrawShape, number[][]>([])

    /** Outline path for `DashStyle.Draw` strokes (only cached when not editing). */
    drawPathCache = new WeakMap<DrawShape['points'], string>([])

    /** Smoothed centre-line path for solid / dashed / dotted strokes and the indicator. */
    simplePathCache = new WeakMap<DrawShape['points'], string>([])

    /** Fill path for filled `DashStyle.Draw` strokes. */
    polygonCache = new WeakMap<DrawShape['points'], string>([])

    defaultProps: DrawShape = {
      id: 'id',
      type: TLDrawShapeType.Draw as const,
      name: 'Draw',
      parentId: 'page',
      childIndex: 1,
      point: [0, 0],
      points: [],
      rotation: 0,
      style: defaultStyle,
    }

    /** Re-render only when the `points` array or the `style` object changes (by reference). */
    shouldRender(prev: DrawShape, next: DrawShape): boolean {
      return next.points !== prev.points || next.style !== prev.style
    }

    /**
     * Renders the stroke as SVG.
     *
     * - Outside of editing, a stroke whose bounds are smaller than half the stroke width in
     *   both dimensions is drawn as a dot.
     * - `DashStyle.Draw` strokes are drawn as a perfect-freehand outline, plus an optional fill.
     * - Other dash styles are drawn as a smoothed centre line with an SVG dash pattern, over a
     *   transparent path that supplies the hit area and the optional fill.
     *
     * A stroke is filled only when `style.isFilled` is set, it has more than three points, and
     * its first and last points are closer than twice the stroke width.
     */
    render(shape: DrawShape, { meta, isEditing }: TLDrawRenderInfo): JSX.Element {
      const { points, style } = shape
      const styles = getShapeStyle(style, meta.isDarkMode)
      const strokeWidth = styles.strokeWidth

      // For very short lines, draw a dot instead of a line.
      const bounds = this.getBounds(shape)
      if (!isEditing && bounds.width < strokeWidth / 2 && bounds.height < strokeWidth / 2) {
        const dotSize = strokeWidth * 0.618
        return (
          <circle
            r={dotSize}
            fill={styles.stroke}
            stroke={styles.stroke}
            strokeWidth={dotSize}
            pointerEvents="all"
          />
        )
      }

      // Only fill strokes that (approximately) close back on their starting point.
      const shouldFill =
        style.isFilled &&
        points.length > 3 &&
        Vec.dist(points[0], points[points.length - 1]) < strokeWidth * 2

      // For drawn lines, use the perfect-freehand outline. While editing the points are still
      // changing, so the outline is rebuilt every render instead of being cached.
      if (style.dash === DashStyle.Draw) {
        const polygonPathData = Utils.getFromCache(this.polygonCache, points, () =>
          getFillPath(shape)
        )

        const drawPathData = isEditing
          ? getDrawStrokePath(shape, true)
          : Utils.getFromCache(this.drawPathCache, points, () => getDrawStrokePath(shape, false))

        return (
          <>
            {shouldFill && (
              <path
                d={polygonPathData}
                stroke="none"
                fill={styles.fill}
                strokeLinejoin="round"
                strokeLinecap="round"
                pointerEvents="fill"
              />
            )}
            <path
              d={drawPathData}
              fill={styles.stroke}
              stroke={styles.stroke}
              strokeWidth={strokeWidth}
              strokeLinejoin="round"
              strokeLinecap="round"
              pointerEvents="all"
            />
          </>
        )
      }

      // For solid, dashed and dotted lines, draw a regular stroke path.
      const strokeDasharray = {
        [DashStyle.Draw]: 'none',
        [DashStyle.Solid]: `none`,
        [DashStyle.Dotted]: `${strokeWidth / 10} ${strokeWidth * 3}`,
        [DashStyle.Dashed]: `${strokeWidth * 3} ${strokeWidth * 3}`,
      }[style.dash]

      const strokeDashoffset = {
        [DashStyle.Draw]: 'none',
        [DashStyle.Solid]: `none`,
        [DashStyle.Dotted]: `-${strokeWidth / 20}`,
        [DashStyle.Dashed]: `-${strokeWidth}`,
      }[style.dash]

      const path = Utils.getFromCache(this.simplePathCache, points, () => getSolidStrokePath(shape))

      const sw = strokeWidth * 1.618

      // The first path is invisible (transparent stroke): it provides the hit area and the
      // optional fill. The second path is the visible, possibly dashed, stroke.
      return (
        <>
          <path
            d={path}
            fill={shouldFill ? styles.fill : 'none'}
            stroke="transparent"
            strokeWidth={Math.min(4, strokeWidth * 2)}
            strokeLinejoin="round"
            strokeLinecap="round"
            pointerEvents={shouldFill ? 'all' : 'stroke'}
          />
          <path
            d={path}
            fill="transparent"
            stroke={styles.stroke}
            strokeWidth={sw}
            strokeDasharray={strokeDasharray}
            strokeDashoffset={strokeDashoffset}
            strokeLinejoin="round"
            strokeLinecap="round"
            pointerEvents="stroke"
          />
        </>
      )
    }

    /** Renders the selection indicator: the smoothed centre-line path, shared with `render`. */
    renderIndicator(shape: DrawShape): JSX.Element {
      const { points } = shape
      const path = Utils.getFromCache(this.simplePathCache, points, () => getSolidStrokePath(shape))
      return <path d={path} />
    }

    /** Axis-aligned bounds of the stroke in page space (cached point bounds + `shape.point`). */
    getBounds(shape: DrawShape): TLBounds {
      return Utils.translateBounds(
        Utils.getFromCache(this.pointsBoundsCache, shape.points, () =>
          Utils.getBoundsFromPoints(shape.points)
        ),
        shape.point
      )
    }

    /** Bounds of the stroke's points after applying `shape.rotation`, translated by `shape.point`. */
    getRotatedBounds(shape: DrawShape): TLBounds {
      return Utils.translateBounds(
        Utils.getBoundsFromPoints(shape.points, shape.rotation),
        shape.point
      )
    }

    /** Centre of the shape's axis-aligned bounds. */
    getCenter(shape: DrawShape): number[] {
      return Utils.getBoundsCenter(this.getBounds(shape))
    }

    /** Point hit-testing always reports a hit for draw shapes. */
    hitTest(): boolean {
      return true
    }

    /**
     * Whether a brush selection touches this shape.
     *
     * The shape is hit if the brush fully contains its bounds (rotated bounds when rotated),
     * or if the brush intersects the stroke's polyline. The polyline test runs in point space,
     * so the brush is translated by `-shape.point` first.
     */
    hitTestBounds(shape: DrawShape, brushBounds: TLBounds): boolean {
      // Test axis-aligned shape. The cheap bounds checks gate the polyline intersection.
      if (!shape.rotation) {
        const bounds = this.getBounds(shape)

        return (
          Utils.boundsContain(brushBounds, bounds) ||
          ((Utils.boundsContain(bounds, brushBounds) ||
            Intersect.bounds.bounds(bounds, brushBounds).length > 0) &&
            Intersect.polyline.bounds(
              shape.points,
              Utils.translateBounds(brushBounds, Vec.neg(shape.point))
            ).length > 0)
        )
      }

      // Test rotated shape: rotate the points about the centre of their own bounds.
      const rBounds = this.getRotatedBounds(shape)

      const rotatedBounds = Utils.getFromCache(this.rotatedCache, shape, () => {
        const c = Utils.getBoundsCenter(Utils.getBoundsFromPoints(shape.points))
        return shape.points.map((pt) => Vec.rotWith(pt, c, shape.rotation || 0))
      })

      return (
        Utils.boundsContain(brushBounds, rBounds) ||
        Intersect.bounds.polyline(
          Utils.translateBounds(brushBounds, Vec.neg(shape.point)),
          rotatedBounds
        ).length > 0
      )
    }

    /**
     * Scales the shape's initial points to fill `bounds`, mirroring an axis when its scale is
     * negative. The pressure component of each point is preserved.
     *
     * @param _shape - the shape as of the previous transform step (not needed: the result is
     *   derived from `initialShape` so that repeated steps do not accumulate error)
     * @param bounds - target page-space bounds
     * @returns the new `points`, plus a `point` that places the new points' bounding box at
     *   `bounds.minX` / `bounds.minY`
     */
    transform(
      _shape: DrawShape,
      bounds: TLBounds,
      { initialShape, scaleX, scaleY }: TLTransformInfo<DrawShape>
    ): Partial<DrawShape> {
      // Untranslated bounds of the points as they were when the transform began.
      const initialShapeBounds = Utils.getFromCache(
        this.pointsBoundsCache,
        initialShape.points,
        () => Utils.getBoundsFromPoints(initialShape.points)
      )

      // Express a coordinate as a fraction of the initial extent. A degenerate (zero-size)
      // axis, e.g. a perfectly vertical line, maps to 0 instead of NaN / Infinity.
      const normalize = (value: number, size: number): number => (size === 0 ? 0 : value / size)

      const points = initialShape.points.map(([x, y, r]) => {
        const nx = normalize(x, initialShapeBounds.width)
        const ny = normalize(y, initialShapeBounds.height)
        return [
          bounds.width * (scaleX < 0 ? 1 - nx : nx),
          bounds.height * (scaleY < 0 ? 1 - ny : ny),
          r,
        ]
      })

      // Measure the *new* points. The previous frame's points have a different
      // extent, so using them would mis-place the shape.
      const newBounds = Utils.getBoundsFromPoints(points)

      const point = Vec.sub([bounds.minX, bounds.minY], [newBounds.minX, newBounds.minY])

      return {
        points,
        point,
      }
    }

    /** Single-shape transforms use the same logic as multi-shape transforms. */
    transformSingle(
      shape: DrawShape,
      bounds: TLBounds,
      info: TLTransformInfo<DrawShape>
    ): Partial<DrawShape> {
      return this.transform(shape, bounds, info)
    }

    /**
     * Normalises the stroke when a session ends: shifts the points so their bounding box
     * starts at [0, 0] and moves `shape.point` by the same amount, leaving the stroke's
     * on-canvas position unchanged.
     */
    onSessionComplete(shape: DrawShape): Partial<DrawShape> {
      const bounds = this.getBounds(shape)

      // Minimum corner of the points' bounds, relative to shape.point.
      const [x1, y1] = Vec.sub([bounds.minX, bounds.minY], shape.point)

      return {
        points: shape.points.map(([x0, y0, p]) => [x0 - x1, y0 - y1, p]),
        point: Vec.add(shape.point, [x1, y1]),
      }
    }
  }
  /**
   * perfect-freehand options used when the input carries no usable pressure data:
   * the library simulates pressure itself (see `getDrawStrokePath`).
   */
  const simulatePressureSettings = { simulatePressure: true }

  /**
   * perfect-freehand options used when real pressure data is present. Pressure is eased
   * quadratically and both ends get a fixed taper of 1. Because these options are spread
   * after the defaults in `getDrawStrokePath`, this taper overrides the width-based default.
   */
  const realPressureSettings = {
    easing(t: number) {
      return t * t
    },
    simulatePressure: false,
    start: { taper: 1 },
    end: { taper: 1 },
  }
  /**
   * Builds the SVG path used to fill a closed `DashStyle.Draw` stroke.
   *
   * The shape's points are smoothed with perfect-freehand's `getStrokePoints`, and the
   * resulting centre-line points are passed to `Utils.getSvgPathFromStroke`.
   * Fewer than two points yields an empty path.
   */
  function getFillPath(shape: DrawShape) {
    const { strokeWidth } = getShapeStyle(shape.style)
    const { points } = shape

    if (points.length < 2) return ''

    const taper = +strokeWidth * 10

    const smoothed = getStrokePoints(points, {
      size: 1 + strokeWidth * 2,
      thinning: 0.85,
      end: { taper },
      start: { taper },
    })

    return Utils.getSvgPathFromStroke(smoothed.map(({ point }) => point))
  }
  /**
   * Builds the perfect-freehand outline path for a `DashStyle.Draw` stroke.
   *
   * @param shape - the draw shape; fewer than two points yields an empty path
   * @param isEditing - while editing the stroke is still in progress, so `last` is passed as
   *   false to perfect-freehand; otherwise the stroke is finalised with `last: true`
   * @returns an SVG path string built from the `getStroke` outline
   */
  function getDrawStrokePath(shape: DrawShape, isEditing: boolean) {
    const { strokeWidth } = getShapeStyle(shape.style)
    const { points } = shape

    if (points.length < 2) return ''

    // A second point with pressure exactly 0.5 selects simulated pressure.
    // NOTE(review): presumably pressure-less input devices record 0.5 — confirm with the draw session.
    const pressureOptions =
      points[1][2] === 0.5 ? simulatePressureSettings : realPressureSettings

    const taper = +strokeWidth * 50

    // The first two points are excluded from the outline.
    // NOTE(review): presumably to skip noisy initial samples — confirm.
    const outline = getStroke(points.slice(2), {
      size: 1 + strokeWidth * 2,
      thinning: 0.85,
      end: { taper },
      start: { taper },
      ...pressureOptions,
      last: !isEditing,
    })

    return Utils.getSvgPathFromStroke(outline)
  }
  /**
   * Builds a smoothed centre-line SVG path through a stroke's points. It is used for solid,
   * dashed and dotted draw shapes and for the selection indicator.
   *
   * - No points: returns the placeholder path `'M 0 0 L 0 0'`.
   * - One or two points: returns a bare move-to at the first point, so nothing is drawn.
   * - Otherwise the points are smoothed with perfect-freehand's `getStrokePoints`, then joined
   *   with quadratic Béziers. Each point is used as a control point and the midpoint to the
   *   next point as the segment end. The path finishes with a straight line to the last point.
   *
   * Every coordinate of the curve is rounded to two decimal places to keep the string short.
   */
  function getSolidStrokePath(shape: DrawShape) {
    const { points } = shape

    if (points.length === 0) return 'M 0 0 L 0 0'

    if (points.length < 3) return `M ${points[0][0]} ${points[0][1]}`

    const smoothed = getStrokePoints(points).map((pt) => pt.point)

    // Smoothing can collapse near-duplicate input. A Q command needs at least one
    // control/end pair, so fall back to a bare move-to rather than emit a malformed path.
    if (smoothed.length < 2) return `M ${points[0][0]} ${points[0][1]}`

    // Format each number directly. Precision is then uniform for negative coordinates and
    // for the M/L endpoints as well, and no ES2021 String#replaceAll post-processing is needed.
    const fmt = (n: number): string => n.toFixed(2)

    const lastIndex = smoothed.length - 1
    const parts: string[] = ['M', fmt(smoothed[0][0]), fmt(smoothed[0][1]), 'Q']

    smoothed.forEach(([x0, y0], i) => {
      if (i === lastIndex) {
        parts.push('L', fmt(x0), fmt(y0))
        return
      }
      const [x1, y1] = smoothed[i + 1]
      parts.push(fmt(x0), fmt(y0), fmt((x0 + x1) / 2), fmt((y0 + y1) / 2))
    })

    return parts.join(' ')
  }